First XGBoost Model

import pickle
from numpy import loadtxt
from xgboost import XGBClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
data_path = '/home/naji/Desktop/github-repos/machine-learning/nbs/0-datasets/'
pima_file = 'pima-indians-diabetes.csv'
seed = 7
dataset = loadtxt(data_path+pima_file, delimiter=',')
dataset
array([[  6.   , 148.   ,  72.   , ...,   0.627,  50.   ,   1.   ],
       [  1.   ,  85.   ,  66.   , ...,   0.351,  31.   ,   0.   ],
       [  8.   , 183.   ,  64.   , ...,   0.672,  32.   ,   1.   ],
       ...,
       [  5.   , 121.   ,  72.   , ...,   0.245,  30.   ,   0.   ],
       [  1.   , 126.   ,  60.   , ...,   0.349,  47.   ,   1.   ],
       [  1.   ,  93.   ,  70.   , ...,   0.315,  23.   ,   0.   ]])
X = dataset[:, 0:8]
y = dataset[:, 8]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=seed)
model = XGBClassifier()
model.fit(X_train, y_train)
XGBClassifier(base_score=0.5, booster='gbtree', callbacks=None,
              colsample_bylevel=1, colsample_bynode=1, colsample_bytree=1,
              early_stopping_rounds=None, enable_categorical=False,
              eval_metric=None, feature_types=None, gamma=0, gpu_id=-1,
              grow_policy='depthwise', importance_type=None,
              interaction_constraints='', learning_rate=0.300000012,
              max_bin=256, max_cat_threshold=64, max_cat_to_onehot=4,
              max_delta_step=0, max_depth=6, max_leaves=0, min_child_weight=1,
              missing=nan, monotone_constraints='()', n_estimators=100,
              n_jobs=0, num_parallel_tree=1, predictor='auto', random_state=0, ...)
predictions = model.predict(X_test)
accuracy = accuracy_score(y_test, predictions)
print(f'Accuracy is: {accuracy*100: .2f}')
Accuracy is:  74.02

Visualize Individual Trees Within A Model

from xgboost import plot_tree
from matplotlib import pyplot
# plot the first tree in the ensemble (requires the graphviz package)
plot_tree(model)
pyplot.show()

# plot the first tree again, laid out left-to-right instead of top-down
plot_tree(model, num_trees=0, rankdir='LR')
pyplot.show()

pickle.dump(model, open('models/pima.pickle.dat', 'wb'))
print('Saved model to: models/pima.pickle.dat')
loaded_model = pickle.load(open('models/pima.pickle.dat', 'rb'))
print('Loaded model from: models/pima.pickle.dat')
Loaded model from: models/pima.pickle.dat
predictions = loaded_model.predict(X_test)
accuracy = accuracy_score(y_test, predictions)
print(f'Accuracy is: {accuracy*100 : 0.2f}%')
Accuracy is:  74.02%

Feature Importance With XGBoost and Feature Selection

Manually Plot Feature Importance

print(model.feature_importances_)
[0.08907107 0.23959665 0.08799458 0.09824964 0.09801763 0.15170808
 0.09959181 0.13577053]
pyplot.bar(range(len(model.feature_importances_)), model.feature_importances_)
pyplot.show()

Using the Built-in XGBoost Feature Importance Plot

from xgboost import plot_importance
plot_importance(model)
pyplot.show()

Feature Selection with XGBoost Feature Importance Scores
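
The cells for this section were not captured. A minimal sketch of the usual approach, pairing the fitted model above with scikit-learn's SelectFromModel: each importance score is used in turn as a selection threshold, and a fresh classifier is evaluated on the reduced feature set. The variable names reuse the Pima splits already in memory.

from numpy import sort
from sklearn.feature_selection import SelectFromModel

# each importance score becomes a candidate threshold, from least to most selective
thresholds = sort(model.feature_importances_)
for thresh in thresholds:
    # keep only features with importance >= thresh, reusing the already-fitted model
    selection = SelectFromModel(model, threshold=thresh, prefit=True)
    select_X_train = selection.transform(X_train)
    # retrain and evaluate on the reduced feature set
    selection_model = XGBClassifier()
    selection_model.fit(select_X_train, y_train)
    select_X_test = selection.transform(X_test)
    predictions = selection_model.predict(select_X_test)
    accuracy = accuracy_score(y_test, predictions)
    print(f'Thresh={thresh:.3f}, n={select_X_train.shape[1]}, Accuracy: {accuracy*100:.2f}%')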

Monitor Training Performance and Early Stopping

Monitoring Training Performance With XGBoost

eval_set = [(X_test, y_test)]
model = XGBClassifier(eval_metric='error')
model.fit(X_train, y_train, eval_set=eval_set, verbose=True)
[0] validation_0-error:0.28346
[1] validation_0-error:0.25984
[2] validation_0-error:0.25591
[3] validation_0-error:0.24803
[4] validation_0-error:0.24409
[5] validation_0-error:0.24803
[6] validation_0-error:0.25591
[7] validation_0-error:0.24803
[8] validation_0-error:0.25591
[9] validation_0-error:0.24409
[10]    validation_0-error:0.24803
[11]    validation_0-error:0.24409
[12]    validation_0-error:0.23228
[13]    validation_0-error:0.24016
[14]    validation_0-error:0.23622
[15]    validation_0-error:0.24409
[16]    validation_0-error:0.25591
[17]    validation_0-error:0.23622
[18]    validation_0-error:0.24016
[19]    validation_0-error:0.23622
[20]    validation_0-error:0.23622
[21]    validation_0-error:0.23622
[22]    validation_0-error:0.23622
[23]    validation_0-error:0.24409
[24]    validation_0-error:0.24409
[25]    validation_0-error:0.24016
[26]    validation_0-error:0.24409
[27]    validation_0-error:0.24409
[28]    validation_0-error:0.25591
[29]    validation_0-error:0.25197
[30]    validation_0-error:0.24803
[31]    validation_0-error:0.25591
[32]    validation_0-error:0.25591
[33]    validation_0-error:0.25984
[34]    validation_0-error:0.26378
[35]    validation_0-error:0.26378
[36]    validation_0-error:0.26378
[37]    validation_0-error:0.26772
[38]    validation_0-error:0.26378
[39]    validation_0-error:0.25984
[40]    validation_0-error:0.25591
[41]    validation_0-error:0.24409
[42]    validation_0-error:0.24803
[43]    validation_0-error:0.24803
[44]    validation_0-error:0.25591
[45]    validation_0-error:0.25197
[46]    validation_0-error:0.26378
[47]    validation_0-error:0.26378
[48]    validation_0-error:0.26378
[49]    validation_0-error:0.25984
[50]    validation_0-error:0.27165
[51]    validation_0-error:0.26772
[52]    validation_0-error:0.27165
[53]    validation_0-error:0.26772
[54]    validation_0-error:0.26378
[55]    validation_0-error:0.26378
[56]    validation_0-error:0.26378
[57]    validation_0-error:0.26772
[58]    validation_0-error:0.27165
[59]    validation_0-error:0.26772
[60]    validation_0-error:0.27165
[61]    validation_0-error:0.27165
[62]    validation_0-error:0.26772
[63]    validation_0-error:0.26772
[64]    validation_0-error:0.26378
[65]    validation_0-error:0.25984
[66]    validation_0-error:0.27165
[67]    validation_0-error:0.27559
[68]    validation_0-error:0.26772
[69]    validation_0-error:0.26378
[70]    validation_0-error:0.26378
[71]    validation_0-error:0.26772
[72]    validation_0-error:0.26772
[73]    validation_0-error:0.26772
[74]    validation_0-error:0.26772
[75]    validation_0-error:0.26772
[76]    validation_0-error:0.26772
[77]    validation_0-error:0.27165
[78]    validation_0-error:0.26772
[79]    validation_0-error:0.27165
[80]    validation_0-error:0.27165
[81]    validation_0-error:0.28346
[82]    validation_0-error:0.27559
[83]    validation_0-error:0.27165
[84]    validation_0-error:0.27559
[85]    validation_0-error:0.26772
[86]    validation_0-error:0.26772
[87]    validation_0-error:0.26378
[88]    validation_0-error:0.26772
[89]    validation_0-error:0.26378
[90]    validation_0-error:0.27165
[91]    validation_0-error:0.26772
[92]    validation_0-error:0.27165
[93]    validation_0-error:0.26378
[94]    validation_0-error:0.27165
[95]    validation_0-error:0.26378
[96]    validation_0-error:0.25984
[97]    validation_0-error:0.26378
[98]    validation_0-error:0.25984
[99]    validation_0-error:0.25984
XGBClassifier(base_score=0.5, booster='gbtree', callbacks=None,
              colsample_bylevel=1, colsample_bynode=1, colsample_bytree=1,
              early_stopping_rounds=None, enable_categorical=False,
              eval_metric='error', feature_types=None, gamma=0, gpu_id=-1,
              grow_policy='depthwise', importance_type=None,
              interaction_constraints='', learning_rate=0.300000012,
              max_bin=256, max_cat_threshold=64, max_cat_to_onehot=4,
              max_delta_step=0, max_depth=6, max_leaves=0, min_child_weight=1,
              missing=nan, monotone_constraints='()', n_estimators=100,
              n_jobs=0, num_parallel_tree=1, predictor='auto', random_state=0, ...)
predictions = model.predict(X_test)
accuracy = accuracy_score(y_test, predictions)
print(f'Accuracy: {accuracy*100: .2f}')
Accuracy:  74.02

Evaluate XGBoost Models With Learning Curves

model = XGBClassifier(eval_metric=['error', 'logloss'])
eval_set = [(X_train, y_train), (X_test, y_test)]
model.fit(X_train, y_train, eval_set=eval_set, verbose=True)
[0] validation_0-error:0.13619  validation_0-logloss:0.55257    validation_1-error:0.28346  validation_1-logloss:0.60491
[1] validation_0-error:0.10895  validation_0-logloss:0.46754    validation_1-error:0.25984  validation_1-logloss:0.55934
[2] validation_0-error:0.10506  validation_0-logloss:0.40734    validation_1-error:0.25591  validation_1-logloss:0.53068
[3] validation_0-error:0.09144  validation_0-logloss:0.36480    validation_1-error:0.24803  validation_1-logloss:0.51795
[4] validation_0-error:0.08560  validation_0-logloss:0.33012    validation_1-error:0.24409  validation_1-logloss:0.51153
[5] validation_0-error:0.07782  validation_0-logloss:0.29868    validation_1-error:0.24803  validation_1-logloss:0.50934
[6] validation_0-error:0.06809  validation_0-logloss:0.27852    validation_1-error:0.25591  validation_1-logloss:0.50818
[7] validation_0-error:0.06615  validation_0-logloss:0.26182    validation_1-error:0.24803  validation_1-logloss:0.51097
[8] validation_0-error:0.06226  validation_0-logloss:0.24578    validation_1-error:0.25591  validation_1-logloss:0.51760
[9] validation_0-error:0.05642  validation_0-logloss:0.23298    validation_1-error:0.24409  validation_1-logloss:0.51912
[10]    validation_0-error:0.04669  validation_0-logloss:0.21955    validation_1-error:0.24803  validation_1-logloss:0.52503
[11]    validation_0-error:0.04280  validation_0-logloss:0.21051    validation_1-error:0.24409  validation_1-logloss:0.52697
[12]    validation_0-error:0.03502  validation_0-logloss:0.20083    validation_1-error:0.23228  validation_1-logloss:0.53335
[13]    validation_0-error:0.03696  validation_0-logloss:0.19466    validation_1-error:0.24016  validation_1-logloss:0.53905
[14]    validation_0-error:0.03502  validation_0-logloss:0.18725    validation_1-error:0.23622  validation_1-logloss:0.54545
[15]    validation_0-error:0.02918  validation_0-logloss:0.17765    validation_1-error:0.24409  validation_1-logloss:0.54613
[16]    validation_0-error:0.02724  validation_0-logloss:0.16747    validation_1-error:0.25591  validation_1-logloss:0.54982
[17]    validation_0-error:0.02140  validation_0-logloss:0.15879    validation_1-error:0.23622  validation_1-logloss:0.55226
[18]    validation_0-error:0.01946  validation_0-logloss:0.15115    validation_1-error:0.24016  validation_1-logloss:0.55355
[19]    validation_0-error:0.00973  validation_0-logloss:0.14529    validation_1-error:0.23622  validation_1-logloss:0.55847
[20]    validation_0-error:0.00973  validation_0-logloss:0.14282    validation_1-error:0.23622  validation_1-logloss:0.56063
[21]    validation_0-error:0.00778  validation_0-logloss:0.13959    validation_1-error:0.23622  validation_1-logloss:0.56665
[22]    validation_0-error:0.00778  validation_0-logloss:0.13253    validation_1-error:0.23622  validation_1-logloss:0.57418
[23]    validation_0-error:0.00778  validation_0-logloss:0.12705    validation_1-error:0.24409  validation_1-logloss:0.57448
[24]    validation_0-error:0.00778  validation_0-logloss:0.12430    validation_1-error:0.24409  validation_1-logloss:0.57511
[25]    validation_0-error:0.00584  validation_0-logloss:0.12175    validation_1-error:0.24016  validation_1-logloss:0.58052
[26]    validation_0-error:0.00584  validation_0-logloss:0.11715    validation_1-error:0.24409  validation_1-logloss:0.58830
[27]    validation_0-error:0.00584  validation_0-logloss:0.11200    validation_1-error:0.24409  validation_1-logloss:0.59717
[28]    validation_0-error:0.00195  validation_0-logloss:0.10682    validation_1-error:0.25591  validation_1-logloss:0.60530
[29]    validation_0-error:0.00389  validation_0-logloss:0.10413    validation_1-error:0.25197  validation_1-logloss:0.60871
[30]    validation_0-error:0.00195  validation_0-logloss:0.09942    validation_1-error:0.24803  validation_1-logloss:0.61161
[31]    validation_0-error:0.00195  validation_0-logloss:0.09640    validation_1-error:0.25591  validation_1-logloss:0.61695
[32]    validation_0-error:0.00195  validation_0-logloss:0.09168    validation_1-error:0.25591  validation_1-logloss:0.61717
[33]    validation_0-error:0.00000  validation_0-logloss:0.08941    validation_1-error:0.25984  validation_1-logloss:0.62061
[34]    validation_0-error:0.00195  validation_0-logloss:0.08648    validation_1-error:0.26378  validation_1-logloss:0.61886
[35]    validation_0-error:0.00000  validation_0-logloss:0.08371    validation_1-error:0.26378  validation_1-logloss:0.61903
[36]    validation_0-error:0.00000  validation_0-logloss:0.08277    validation_1-error:0.26378  validation_1-logloss:0.62187
[37]    validation_0-error:0.00000  validation_0-logloss:0.08041    validation_1-error:0.26772  validation_1-logloss:0.62557
[38]    validation_0-error:0.00000  validation_0-logloss:0.07842    validation_1-error:0.26378  validation_1-logloss:0.62663
[39]    validation_0-error:0.00000  validation_0-logloss:0.07651    validation_1-error:0.25984  validation_1-logloss:0.62743
[40]    validation_0-error:0.00000  validation_0-logloss:0.07424    validation_1-error:0.25591  validation_1-logloss:0.62667
[41]    validation_0-error:0.00000  validation_0-logloss:0.07202    validation_1-error:0.24409  validation_1-logloss:0.63148
[42]    validation_0-error:0.00000  validation_0-logloss:0.07012    validation_1-error:0.24803  validation_1-logloss:0.63695
[43]    validation_0-error:0.00000  validation_0-logloss:0.06862    validation_1-error:0.24803  validation_1-logloss:0.64021
[44]    validation_0-error:0.00000  validation_0-logloss:0.06629    validation_1-error:0.25591  validation_1-logloss:0.64323
[45]    validation_0-error:0.00000  validation_0-logloss:0.06394    validation_1-error:0.25197  validation_1-logloss:0.64747
[46]    validation_0-error:0.00000  validation_0-logloss:0.06231    validation_1-error:0.26378  validation_1-logloss:0.64921
[47]    validation_0-error:0.00000  validation_0-logloss:0.06090    validation_1-error:0.26378  validation_1-logloss:0.65250
[48]    validation_0-error:0.00000  validation_0-logloss:0.05953    validation_1-error:0.26378  validation_1-logloss:0.65838
[49]    validation_0-error:0.00000  validation_0-logloss:0.05801    validation_1-error:0.25984  validation_1-logloss:0.66152
[50]    validation_0-error:0.00000  validation_0-logloss:0.05643    validation_1-error:0.27165  validation_1-logloss:0.66584
[51]    validation_0-error:0.00000  validation_0-logloss:0.05549    validation_1-error:0.26772  validation_1-logloss:0.66783
[52]    validation_0-error:0.00000  validation_0-logloss:0.05462    validation_1-error:0.27165  validation_1-logloss:0.67103
[53]    validation_0-error:0.00000  validation_0-logloss:0.05347    validation_1-error:0.26772  validation_1-logloss:0.67425
[54]    validation_0-error:0.00000  validation_0-logloss:0.05253    validation_1-error:0.26378  validation_1-logloss:0.67873
[55]    validation_0-error:0.00000  validation_0-logloss:0.05153    validation_1-error:0.26378  validation_1-logloss:0.67768
[56]    validation_0-error:0.00000  validation_0-logloss:0.05051    validation_1-error:0.26378  validation_1-logloss:0.68269
[57]    validation_0-error:0.00000  validation_0-logloss:0.04942    validation_1-error:0.26772  validation_1-logloss:0.68738
[58]    validation_0-error:0.00000  validation_0-logloss:0.04892    validation_1-error:0.27165  validation_1-logloss:0.69011
[59]    validation_0-error:0.00000  validation_0-logloss:0.04799    validation_1-error:0.26772  validation_1-logloss:0.69266
[60]    validation_0-error:0.00000  validation_0-logloss:0.04720    validation_1-error:0.27165  validation_1-logloss:0.69469
[61]    validation_0-error:0.00000  validation_0-logloss:0.04643    validation_1-error:0.27165  validation_1-logloss:0.70239
[62]    validation_0-error:0.00000  validation_0-logloss:0.04535    validation_1-error:0.26772  validation_1-logloss:0.70504
[63]    validation_0-error:0.00000  validation_0-logloss:0.04454    validation_1-error:0.26772  validation_1-logloss:0.70622
[64]    validation_0-error:0.00000  validation_0-logloss:0.04379    validation_1-error:0.26378  validation_1-logloss:0.70810
[65]    validation_0-error:0.00000  validation_0-logloss:0.04315    validation_1-error:0.25984  validation_1-logloss:0.71247
[66]    validation_0-error:0.00000  validation_0-logloss:0.04241    validation_1-error:0.27165  validation_1-logloss:0.71706
[67]    validation_0-error:0.00000  validation_0-logloss:0.04163    validation_1-error:0.27559  validation_1-logloss:0.71636
[68]    validation_0-error:0.00000  validation_0-logloss:0.04085    validation_1-error:0.26772  validation_1-logloss:0.71625
[69]    validation_0-error:0.00000  validation_0-logloss:0.04036    validation_1-error:0.26378  validation_1-logloss:0.71904
[70]    validation_0-error:0.00000  validation_0-logloss:0.03993    validation_1-error:0.26378  validation_1-logloss:0.72348
[71]    validation_0-error:0.00000  validation_0-logloss:0.03907    validation_1-error:0.26772  validation_1-logloss:0.72573
[72]    validation_0-error:0.00000  validation_0-logloss:0.03835    validation_1-error:0.26772  validation_1-logloss:0.72761
[73]    validation_0-error:0.00000  validation_0-logloss:0.03762    validation_1-error:0.26772  validation_1-logloss:0.72992
[74]    validation_0-error:0.00000  validation_0-logloss:0.03719    validation_1-error:0.26772  validation_1-logloss:0.73336
[75]    validation_0-error:0.00000  validation_0-logloss:0.03669    validation_1-error:0.26772  validation_1-logloss:0.73444
[76]    validation_0-error:0.00000  validation_0-logloss:0.03632    validation_1-error:0.26772  validation_1-logloss:0.73795
[77]    validation_0-error:0.00000  validation_0-logloss:0.03588    validation_1-error:0.27165  validation_1-logloss:0.74054
[78]    validation_0-error:0.00000  validation_0-logloss:0.03521    validation_1-error:0.26772  validation_1-logloss:0.74512
[79]    validation_0-error:0.00000  validation_0-logloss:0.03464    validation_1-error:0.27165  validation_1-logloss:0.74767
[80]    validation_0-error:0.00000  validation_0-logloss:0.03432    validation_1-error:0.27165  validation_1-logloss:0.74878
[81]    validation_0-error:0.00000  validation_0-logloss:0.03380    validation_1-error:0.28346  validation_1-logloss:0.75047
[82]    validation_0-error:0.00000  validation_0-logloss:0.03343    validation_1-error:0.27559  validation_1-logloss:0.75475
[83]    validation_0-error:0.00000  validation_0-logloss:0.03297    validation_1-error:0.27165  validation_1-logloss:0.75587
[84]    validation_0-error:0.00000  validation_0-logloss:0.03245    validation_1-error:0.27559  validation_1-logloss:0.75861
[85]    validation_0-error:0.00000  validation_0-logloss:0.03208    validation_1-error:0.26772  validation_1-logloss:0.75890
[86]    validation_0-error:0.00000  validation_0-logloss:0.03169    validation_1-error:0.26772  validation_1-logloss:0.76230
[87]    validation_0-error:0.00000  validation_0-logloss:0.03139    validation_1-error:0.26378  validation_1-logloss:0.76483
[88]    validation_0-error:0.00000  validation_0-logloss:0.03111    validation_1-error:0.26772  validation_1-logloss:0.76738
[89]    validation_0-error:0.00000  validation_0-logloss:0.03077    validation_1-error:0.26378  validation_1-logloss:0.77021
[90]    validation_0-error:0.00000  validation_0-logloss:0.03041    validation_1-error:0.27165  validation_1-logloss:0.77393
[91]    validation_0-error:0.00000  validation_0-logloss:0.03003    validation_1-error:0.26772  validation_1-logloss:0.77259
[92]    validation_0-error:0.00000  validation_0-logloss:0.02976    validation_1-error:0.27165  validation_1-logloss:0.77214
[93]    validation_0-error:0.00000  validation_0-logloss:0.02947    validation_1-error:0.26378  validation_1-logloss:0.77362
[94]    validation_0-error:0.00000  validation_0-logloss:0.02912    validation_1-error:0.27165  validation_1-logloss:0.77521
[95]    validation_0-error:0.00000  validation_0-logloss:0.02877    validation_1-error:0.26378  validation_1-logloss:0.77405
[96]    validation_0-error:0.00000  validation_0-logloss:0.02853    validation_1-error:0.25984  validation_1-logloss:0.77413
[97]    validation_0-error:0.00000  validation_0-logloss:0.02833    validation_1-error:0.26378  validation_1-logloss:0.77805
[98]    validation_0-error:0.00000  validation_0-logloss:0.02809    validation_1-error:0.25984  validation_1-logloss:0.77660
[99]    validation_0-error:0.00000  validation_0-logloss:0.02787    validation_1-error:0.25984  validation_1-logloss:0.77681
XGBClassifier(base_score=0.5, booster='gbtree', callbacks=None,
              colsample_bylevel=1, colsample_bynode=1, colsample_bytree=1,
              early_stopping_rounds=None, enable_categorical=False,
              eval_metric=['error', 'logloss'], feature_types=None, gamma=0,
              gpu_id=-1, grow_policy='depthwise', importance_type=None,
              interaction_constraints='', learning_rate=0.300000012,
              max_bin=256, max_cat_threshold=64, max_cat_to_onehot=4,
              max_delta_step=0, max_depth=6, max_leaves=0, min_child_weight=1,
              missing=nan, monotone_constraints='()', n_estimators=100,
              n_jobs=0, num_parallel_tree=1, predictor='auto', random_state=0, ...)
predictions = model.predict(X_test)
accuracy = accuracy_score(y_test, predictions)
print(f'Accuracy: {accuracy*100 : 0.2f}%')
Accuracy:  74.02%
results = model.evals_result()
epochs = len(results['validation_0']['error'])
x_axis = range(0, epochs)
# plot log loss
fig, ax = pyplot.subplots()
ax.plot(x_axis, results['validation_0']['logloss'], label='Train')
ax.plot(x_axis, results['validation_1']['logloss'], label='Test')
ax.legend()
pyplot.ylabel('Log Loss')
pyplot.title('XGBoost Log Loss')
pyplot.show()

# plot classification error
fig, ax = pyplot.subplots()
ax.plot(x_axis, results['validation_0']['error'], label='Train')
ax.plot(x_axis, results['validation_1']['error'], label='Test')
ax.legend()
pyplot.ylabel('Classification Error')
pyplot.title('XGBoost Classification Error')
pyplot.show()

Early Stopping With XGBoost

model = XGBClassifier(eval_metric='logloss')
eval_set = [(X_test, y_test)]
model.fit(X_train, y_train, early_stopping_rounds=10, eval_set=eval_set, verbose=True)
[0] validation_0-logloss:0.60491
[1] validation_0-logloss:0.55934
[2] validation_0-logloss:0.53068
[3] validation_0-logloss:0.51795
[4] validation_0-logloss:0.51153
[5] validation_0-logloss:0.50934
[6] validation_0-logloss:0.50818
[7] validation_0-logloss:0.51097
[8] validation_0-logloss:0.51760
[9] validation_0-logloss:0.51912
[10]    validation_0-logloss:0.52503
[11]    validation_0-logloss:0.52697
[12]    validation_0-logloss:0.53335
[13]    validation_0-logloss:0.53905
[14]    validation_0-logloss:0.54545
[15]    validation_0-logloss:0.54613
XGBClassifier(base_score=0.5, booster='gbtree', callbacks=None,
              colsample_bylevel=1, colsample_bynode=1, colsample_bytree=1,
              early_stopping_rounds=None, enable_categorical=False,
              eval_metric='logloss', feature_types=None, gamma=0, gpu_id=-1,
              grow_policy='depthwise', importance_type=None,
              interaction_constraints='', learning_rate=0.300000012,
              max_bin=256, max_cat_threshold=64, max_cat_to_onehot=4,
              max_delta_step=0, max_depth=6, max_leaves=0, min_child_weight=1,
              missing=nan, monotone_constraints='()', n_estimators=100,
              n_jobs=0, num_parallel_tree=1, predictor='auto', random_state=0, ...)
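
After early stopping, the boosting round with the best validation score can be read back from the fitted model. A short check, assuming the xgboost version used here (the sklearn wrapper exposes best_iteration and best_score once early stopping has triggered):

# the round with the lowest validation logloss (round 6 in the log above)
print(f'Best iteration: {model.best_iteration}')
print(f'Best validation logloss: {model.best_score}')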

Tune Multithreading Support for XGBoost
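
No cells were recorded for this section. A minimal timing sketch for comparing thread counts on the Pima data already in memory; the value list is illustrative, and n_jobs=-1 asks XGBoost for all available cores. Timings are machine-dependent.

import time

# fit the same model with different thread counts and report wall-clock time
for n in [1, 2, 4, -1]:
    start = time.time()
    XGBClassifier(n_jobs=n).fit(X_train, y_train)
    print(f'n_jobs={n}: {time.time() - start:.3f} seconds')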

XGBoost Tuning

Tune the Number and Size of Decision Trees with XGBoost

import matplotlib
matplotlib.use('Agg')  # select a non-interactive backend before pyplot is imported
from matplotlib import pyplot
from pandas import read_csv
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.preprocessing import LabelEncoder

Tune the Number of Decision Trees

file_path = '/home/naji/Desktop/github-repos/machine-learning/nbs/0-datasets/otto/'
data = read_csv(file_path + 'train.csv')
data
[DataFrame preview omitted: columns are id, feat_1 … feat_93, and target (Class_1 … Class_9)]

61878 rows × 95 columns

dataset = data.values
# column 0 is the row id (not a feature), columns 1-93 are the features,
# and column 94 is the class label
X = dataset[:, 1:94].astype(float)
y = dataset[:, 94]
label_encoded_y = LabelEncoder().fit_transform(y)
model = XGBClassifier()
n_estimators = range(50, 150, 50)
param_grid = dict(n_estimators=n_estimators)
kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=7)
grid_search = GridSearchCV(model, param_grid, scoring='neg_log_loss', cv=kfold)
grid_result = grid_search.fit(X, label_encoded_y)
KeyboardInterrupt: 
# summarize results
print(f'Best: {grid_result.best_score_} using {grid_result.best_params_}')
means = grid_result.cv_results_['mean_test_score']
stds = grid_result.cv_results_['std_test_score']
params = grid_result.cv_results_['params']
for mean, stdev, param in zip(means, stds, params):
    print(f'{mean} ({stdev}) with {param}')
pyplot.errorbar(n_estimators, means, yerr=stds)
pyplot.title("XGBoost n_estimators vs Log Loss")
pyplot.xlabel('n_estimators')
pyplot.ylabel('Log Loss')
pyplot.savefig('n_estimators.png')

Tune the Size of Decision Trees

model = XGBClassifier()
max_depth = range(1, 5, 2)
param_grid = dict(max_depth=max_depth)
kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=7)
grid_search = GridSearchCV(model, param_grid, scoring='neg_log_loss', cv=kfold, verbose=1)
grid_result = grid_search.fit(X, label_encoded_y)
Fitting 10 folds for each of 2 candidates, totalling 20 fits
KeyboardInterrupt: 
# summarize results
print(f'Best: {grid_result.best_score_} using {grid_result.best_params_}')
means = grid_result.cv_results_['mean_test_score']
stds = grid_result.cv_results_['std_test_score']
params = grid_result.cv_results_['params']
for mean, stdev, param in zip(means, stds, params):
    print(f'{mean} ({stdev}) with {param}')
params = grid_result.cv_results_['params']

Tune the Number and Size of Trees
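
This section had no cells. A sketch of the combined grid over both parameters, reusing the CV setup from the previous sections; the value lists are kept small here purely to keep the search tractable.

model = XGBClassifier()
# search tree count and tree depth jointly: 3 x 3 = 9 candidates
param_grid = dict(n_estimators=[50, 100, 150], max_depth=[2, 4, 6])
kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=7)
grid_search = GridSearchCV(model, param_grid, scoring='neg_log_loss', cv=kfold, verbose=1)
grid_result = grid_search.fit(X, label_encoded_y)
print(f'Best: {grid_result.best_score_} using {grid_result.best_params_}')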

Tune Learning Rate and Number of Trees with XGBoost
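
Again no cells were captured. The usual pattern trades a smaller learning_rate against a larger number of trees, since each tree then contributes less to the final prediction. A sketch with illustrative value lists:

model = XGBClassifier()
# smaller learning rates generally need more trees to reach the same loss
param_grid = dict(learning_rate=[0.001, 0.01, 0.1, 0.3], n_estimators=[100, 200, 300])
kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=7)
grid_search = GridSearchCV(model, param_grid, scoring='neg_log_loss', cv=kfold)
grid_result = grid_search.fit(X, label_encoded_y)
print(f'Best: {grid_result.best_score_} using {grid_result.best_params_}')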

Tuning Stochastic Gradient Boosting with XGBoost
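
This final section was also left empty. Stochastic gradient boosting subsamples rows per tree (subsample) and columns per tree or per split level (colsample_bytree, colsample_bylevel). A sketch over all three knobs with illustrative values; the full grid is 4 x 4 x 4 = 64 candidates, so in practice each parameter is often tuned one at a time.

model = XGBClassifier()
# row subsampling per tree, column subsampling per tree, and per tree level
param_grid = dict(subsample=[0.3, 0.5, 0.7, 1.0],
                  colsample_bytree=[0.3, 0.5, 0.7, 1.0],
                  colsample_bylevel=[0.3, 0.5, 0.7, 1.0])
kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=7)
grid_search = GridSearchCV(model, param_grid, scoring='neg_log_loss', cv=kfold)
grid_result = grid_search.fit(X, label_encoded_y)
print(f'Best: {grid_result.best_score_} using {grid_result.best_params_}')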