import pickle
from numpy import loadtxt
from xgboost import XGBClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
First XGBoost Model
# Location of the Pima Indians diabetes CSV (8 numeric features + binary label).
data_path = '/home/naji/Desktop/github-repos/machine-learning/nbs/0-datasets/'
pima_file = 'pima-indians-diabetes.csv'
seed = 7  # fixed random_state so the train/test split is reproducible

# Load the whole CSV into a float ndarray of shape (n_samples, 9).
dataset = loadtxt(data_path + pima_file, delimiter=',')
dataset
array([[ 6. , 148. , 72. , ..., 0.627, 50. , 1. ],
[ 1. , 85. , 66. , ..., 0.351, 31. , 0. ],
[ 8. , 183. , 64. , ..., 0.672, 32. , 1. ],
...,
[ 5. , 121. , 72. , ..., 0.245, 30. , 0. ],
[ 1. , 126. , 60. , ..., 0.349, 47. , 1. ],
[ 1. , 93. , 70. , ..., 0.315, 23. , 0. ]])
# Split columns into features (first 8) and the binary class label (9th).
X = dataset[:, 0:8]
y = dataset[:, 8]
# Hold out 33% of the rows for evaluation; seed keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=seed)
# Fit an XGBoost classifier with default hyperparameters.
model = XGBClassifier()
model.fit(X_train, y_train)
XGBClassifier(base_score=0.5, booster='gbtree', callbacks=None, colsample_bylevel=1, colsample_bynode=1, colsample_bytree=1, early_stopping_rounds=None, enable_categorical=False, eval_metric=None, feature_types=None, gamma=0, gpu_id=-1, grow_policy='depthwise', importance_type=None, interaction_constraints='', learning_rate=0.300000012, max_bin=256, max_cat_threshold=64, max_cat_to_onehot=4, max_delta_step=0, max_depth=6, max_leaves=0, min_child_weight=1, missing=nan, monotone_constraints='()', n_estimators=100, n_jobs=0, num_parallel_tree=1, predictor='auto', random_state=0, ...)In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
XGBClassifier(base_score=0.5, booster='gbtree', callbacks=None, colsample_bylevel=1, colsample_bynode=1, colsample_bytree=1, early_stopping_rounds=None, enable_categorical=False, eval_metric=None, feature_types=None, gamma=0, gpu_id=-1, grow_policy='depthwise', importance_type=None, interaction_constraints='', learning_rate=0.300000012, max_bin=256, max_cat_threshold=64, max_cat_to_onehot=4, max_delta_step=0, max_depth=6, max_leaves=0, min_child_weight=1, missing=nan, monotone_constraints='()', n_estimators=100, n_jobs=0, num_parallel_tree=1, predictor='auto', random_state=0, ...)
# Score the fitted model on the held-out test set.
predictions = model.predict(X_test)
accuracy = accuracy_score(y_test, predictions)
print(f'Accuracy is: {accuracy*100: .2f}')
Accuracy is: 74.02
Visualize Individual Trees Within A Model
from xgboost import plot_tree
from matplotlib import pyplot

# Render the first boosted tree (default layout), then again left-to-right.
plot_tree(model)
pyplot.show()
plot_tree(model, num_trees=0, rankdir='LR')
pyplot.show()

# Persist the trained model; `with` guarantees the file handle is closed.
with open('models/pima.pickle.dat', 'wb') as f:
    pickle.dump(model, f)
print('Saved model to: pima.pickle.dat')
# Reload the pickled model; `with` closes the handle the original left open.
# NOTE(review): `loadet_model` is a typo for `loaded_model`, but later cells
# reference this exact name, so it is kept for compatibility.
with open('models/pima.pickle.dat', 'rb') as f:
    loadet_model = pickle.load(f)
print('Loaded model from: pima.pickle.dat')
Loaded model from: pima.pickle.dat
# Confirm the round-tripped model scores identically to the in-memory one.
predictions = loadet_model.predict(X_test)
accuracy = accuracy_score(y_test, predictions)
print(f'Accuracy is: {accuracy*100 : 0.2f}%')
Accuracy is: 74.02%
Feature Importance With XGBoost and Feature Selection
Manually Plot Feature Importance
# Importance score of each of the 8 input features, in column order.
print(model.feature_importances_)
[0.08907107 0.23959665 0.08799458 0.09824964 0.09801763 0.15170808
0.09959181 0.13577053]
# Manual bar chart of the per-feature importance scores.
pyplot.bar(range(len(model.feature_importances_)), model.feature_importances_)
pyplot.show()
Using the Built-in XGBoost Feature Importance Plot
from xgboost import plot_importance

# XGBoost's built-in importance plot (features ranked by F-score).
plot_importance(model)
pyplot.show()
Feature Selection with XGBoost Feature Importance Scores
Monitor Training Performance and Early Stopping
Monitoring Training Performance With XGBoost
# Monitor classification error on the test set after every boosting round.
eval_set = [(X_test, y_test)]
model = XGBClassifier(eval_metric='error')
model.fit(X_train, y_train, eval_set=eval_set, verbose=True)
[0] validation_0-error:0.28346
[1] validation_0-error:0.25984
[2] validation_0-error:0.25591
[3] validation_0-error:0.24803
[4] validation_0-error:0.24409
[5] validation_0-error:0.24803
[6] validation_0-error:0.25591
[7] validation_0-error:0.24803
[8] validation_0-error:0.25591
[9] validation_0-error:0.24409
[10] validation_0-error:0.24803
[11] validation_0-error:0.24409
[12] validation_0-error:0.23228
[13] validation_0-error:0.24016
[14] validation_0-error:0.23622
[15] validation_0-error:0.24409
[16] validation_0-error:0.25591
[17] validation_0-error:0.23622
[18] validation_0-error:0.24016
[19] validation_0-error:0.23622
[20] validation_0-error:0.23622
[21] validation_0-error:0.23622
[22] validation_0-error:0.23622
[23] validation_0-error:0.24409
[24] validation_0-error:0.24409
[25] validation_0-error:0.24016
[26] validation_0-error:0.24409
[27] validation_0-error:0.24409
[28] validation_0-error:0.25591
[29] validation_0-error:0.25197
[30] validation_0-error:0.24803
[31] validation_0-error:0.25591
[32] validation_0-error:0.25591
[33] validation_0-error:0.25984
[34] validation_0-error:0.26378
[35] validation_0-error:0.26378
[36] validation_0-error:0.26378
[37] validation_0-error:0.26772
[38] validation_0-error:0.26378
[39] validation_0-error:0.25984
[40] validation_0-error:0.25591
[41] validation_0-error:0.24409
[42] validation_0-error:0.24803
[43] validation_0-error:0.24803
[44] validation_0-error:0.25591
[45] validation_0-error:0.25197
[46] validation_0-error:0.26378
[47] validation_0-error:0.26378
[48] validation_0-error:0.26378
[49] validation_0-error:0.25984
[50] validation_0-error:0.27165
[51] validation_0-error:0.26772
[52] validation_0-error:0.27165
[53] validation_0-error:0.26772
[54] validation_0-error:0.26378
[55] validation_0-error:0.26378
[56] validation_0-error:0.26378
[57] validation_0-error:0.26772
[58] validation_0-error:0.27165
[59] validation_0-error:0.26772
[60] validation_0-error:0.27165
[61] validation_0-error:0.27165
[62] validation_0-error:0.26772
[63] validation_0-error:0.26772
[64] validation_0-error:0.26378
[65] validation_0-error:0.25984
[66] validation_0-error:0.27165
[67] validation_0-error:0.27559
[68] validation_0-error:0.26772
[69] validation_0-error:0.26378
[70] validation_0-error:0.26378
[71] validation_0-error:0.26772
[72] validation_0-error:0.26772
[73] validation_0-error:0.26772
[74] validation_0-error:0.26772
[75] validation_0-error:0.26772
[76] validation_0-error:0.26772
[77] validation_0-error:0.27165
[78] validation_0-error:0.26772
[79] validation_0-error:0.27165
[80] validation_0-error:0.27165
[81] validation_0-error:0.28346
[82] validation_0-error:0.27559
[83] validation_0-error:0.27165
[84] validation_0-error:0.27559
[85] validation_0-error:0.26772
[86] validation_0-error:0.26772
[87] validation_0-error:0.26378
[88] validation_0-error:0.26772
[89] validation_0-error:0.26378
[90] validation_0-error:0.27165
[91] validation_0-error:0.26772
[92] validation_0-error:0.27165
[93] validation_0-error:0.26378
[94] validation_0-error:0.27165
[95] validation_0-error:0.26378
[96] validation_0-error:0.25984
[97] validation_0-error:0.26378
[98] validation_0-error:0.25984
[99] validation_0-error:0.25984
XGBClassifier(base_score=0.5, booster='gbtree', callbacks=None, colsample_bylevel=1, colsample_bynode=1, colsample_bytree=1, early_stopping_rounds=None, enable_categorical=False, eval_metric='error', feature_types=None, gamma=0, gpu_id=-1, grow_policy='depthwise', importance_type=None, interaction_constraints='', learning_rate=0.300000012, max_bin=256, max_cat_threshold=64, max_cat_to_onehot=4, max_delta_step=0, max_depth=6, max_leaves=0, min_child_weight=1, missing=nan, monotone_constraints='()', n_estimators=100, n_jobs=0, num_parallel_tree=1, predictor='auto', random_state=0, ...)In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
XGBClassifier(base_score=0.5, booster='gbtree', callbacks=None, colsample_bylevel=1, colsample_bynode=1, colsample_bytree=1, early_stopping_rounds=None, enable_categorical=False, eval_metric='error', feature_types=None, gamma=0, gpu_id=-1, grow_policy='depthwise', importance_type=None, interaction_constraints='', learning_rate=0.300000012, max_bin=256, max_cat_threshold=64, max_cat_to_onehot=4, max_delta_step=0, max_depth=6, max_leaves=0, min_child_weight=1, missing=nan, monotone_constraints='()', n_estimators=100, n_jobs=0, num_parallel_tree=1, predictor='auto', random_state=0, ...)
# Final test-set accuracy of the monitored model.
predictions = model.predict(X_test)
accuracy = accuracy_score(y_test, predictions)
print(f'Accuracy: {accuracy*100: .2f}')
Accuracy: 74.02
Evaluate XGBoost Models With Learning Curves
# Track both error and log loss on train AND test sets each round,
# so learning curves can be plotted from model.evals_result() afterwards.
model = XGBClassifier(eval_metric=['error', 'logloss'])
eval_set = [(X_train, y_train), (X_test, y_test)]
model.fit(X_train, y_train, eval_set=eval_set, verbose=True)
[0] validation_0-error:0.13619 validation_0-logloss:0.55257 validation_1-error:0.28346 validation_1-logloss:0.60491
[1] validation_0-error:0.10895 validation_0-logloss:0.46754 validation_1-error:0.25984 validation_1-logloss:0.55934
[2] validation_0-error:0.10506 validation_0-logloss:0.40734 validation_1-error:0.25591 validation_1-logloss:0.53068
[3] validation_0-error:0.09144 validation_0-logloss:0.36480 validation_1-error:0.24803 validation_1-logloss:0.51795
[4] validation_0-error:0.08560 validation_0-logloss:0.33012 validation_1-error:0.24409 validation_1-logloss:0.51153
[5] validation_0-error:0.07782 validation_0-logloss:0.29868 validation_1-error:0.24803 validation_1-logloss:0.50934
[6] validation_0-error:0.06809 validation_0-logloss:0.27852 validation_1-error:0.25591 validation_1-logloss:0.50818
[7] validation_0-error:0.06615 validation_0-logloss:0.26182 validation_1-error:0.24803 validation_1-logloss:0.51097
[8] validation_0-error:0.06226 validation_0-logloss:0.24578 validation_1-error:0.25591 validation_1-logloss:0.51760
[9] validation_0-error:0.05642 validation_0-logloss:0.23298 validation_1-error:0.24409 validation_1-logloss:0.51912
[10] validation_0-error:0.04669 validation_0-logloss:0.21955 validation_1-error:0.24803 validation_1-logloss:0.52503
[11] validation_0-error:0.04280 validation_0-logloss:0.21051 validation_1-error:0.24409 validation_1-logloss:0.52697
[12] validation_0-error:0.03502 validation_0-logloss:0.20083 validation_1-error:0.23228 validation_1-logloss:0.53335
[13] validation_0-error:0.03696 validation_0-logloss:0.19466 validation_1-error:0.24016 validation_1-logloss:0.53905
[14] validation_0-error:0.03502 validation_0-logloss:0.18725 validation_1-error:0.23622 validation_1-logloss:0.54545
[15] validation_0-error:0.02918 validation_0-logloss:0.17765 validation_1-error:0.24409 validation_1-logloss:0.54613
[16] validation_0-error:0.02724 validation_0-logloss:0.16747 validation_1-error:0.25591 validation_1-logloss:0.54982
[17] validation_0-error:0.02140 validation_0-logloss:0.15879 validation_1-error:0.23622 validation_1-logloss:0.55226
[18] validation_0-error:0.01946 validation_0-logloss:0.15115 validation_1-error:0.24016 validation_1-logloss:0.55355
[19] validation_0-error:0.00973 validation_0-logloss:0.14529 validation_1-error:0.23622 validation_1-logloss:0.55847
[20] validation_0-error:0.00973 validation_0-logloss:0.14282 validation_1-error:0.23622 validation_1-logloss:0.56063
[21] validation_0-error:0.00778 validation_0-logloss:0.13959 validation_1-error:0.23622 validation_1-logloss:0.56665
[22] validation_0-error:0.00778 validation_0-logloss:0.13253 validation_1-error:0.23622 validation_1-logloss:0.57418
[23] validation_0-error:0.00778 validation_0-logloss:0.12705 validation_1-error:0.24409 validation_1-logloss:0.57448
[24] validation_0-error:0.00778 validation_0-logloss:0.12430 validation_1-error:0.24409 validation_1-logloss:0.57511
[25] validation_0-error:0.00584 validation_0-logloss:0.12175 validation_1-error:0.24016 validation_1-logloss:0.58052
[26] validation_0-error:0.00584 validation_0-logloss:0.11715 validation_1-error:0.24409 validation_1-logloss:0.58830
[27] validation_0-error:0.00584 validation_0-logloss:0.11200 validation_1-error:0.24409 validation_1-logloss:0.59717
[28] validation_0-error:0.00195 validation_0-logloss:0.10682 validation_1-error:0.25591 validation_1-logloss:0.60530
[29] validation_0-error:0.00389 validation_0-logloss:0.10413 validation_1-error:0.25197 validation_1-logloss:0.60871
[30] validation_0-error:0.00195 validation_0-logloss:0.09942 validation_1-error:0.24803 validation_1-logloss:0.61161
[31] validation_0-error:0.00195 validation_0-logloss:0.09640 validation_1-error:0.25591 validation_1-logloss:0.61695
[32] validation_0-error:0.00195 validation_0-logloss:0.09168 validation_1-error:0.25591 validation_1-logloss:0.61717
[33] validation_0-error:0.00000 validation_0-logloss:0.08941 validation_1-error:0.25984 validation_1-logloss:0.62061
[34] validation_0-error:0.00195 validation_0-logloss:0.08648 validation_1-error:0.26378 validation_1-logloss:0.61886
[35] validation_0-error:0.00000 validation_0-logloss:0.08371 validation_1-error:0.26378 validation_1-logloss:0.61903
[36] validation_0-error:0.00000 validation_0-logloss:0.08277 validation_1-error:0.26378 validation_1-logloss:0.62187
[37] validation_0-error:0.00000 validation_0-logloss:0.08041 validation_1-error:0.26772 validation_1-logloss:0.62557
[38] validation_0-error:0.00000 validation_0-logloss:0.07842 validation_1-error:0.26378 validation_1-logloss:0.62663
[39] validation_0-error:0.00000 validation_0-logloss:0.07651 validation_1-error:0.25984 validation_1-logloss:0.62743
[40] validation_0-error:0.00000 validation_0-logloss:0.07424 validation_1-error:0.25591 validation_1-logloss:0.62667
[41] validation_0-error:0.00000 validation_0-logloss:0.07202 validation_1-error:0.24409 validation_1-logloss:0.63148
[42] validation_0-error:0.00000 validation_0-logloss:0.07012 validation_1-error:0.24803 validation_1-logloss:0.63695
[43] validation_0-error:0.00000 validation_0-logloss:0.06862 validation_1-error:0.24803 validation_1-logloss:0.64021
[44] validation_0-error:0.00000 validation_0-logloss:0.06629 validation_1-error:0.25591 validation_1-logloss:0.64323
[45] validation_0-error:0.00000 validation_0-logloss:0.06394 validation_1-error:0.25197 validation_1-logloss:0.64747
[46] validation_0-error:0.00000 validation_0-logloss:0.06231 validation_1-error:0.26378 validation_1-logloss:0.64921
[47] validation_0-error:0.00000 validation_0-logloss:0.06090 validation_1-error:0.26378 validation_1-logloss:0.65250
[48] validation_0-error:0.00000 validation_0-logloss:0.05953 validation_1-error:0.26378 validation_1-logloss:0.65838
[49] validation_0-error:0.00000 validation_0-logloss:0.05801 validation_1-error:0.25984 validation_1-logloss:0.66152
[50] validation_0-error:0.00000 validation_0-logloss:0.05643 validation_1-error:0.27165 validation_1-logloss:0.66584
[51] validation_0-error:0.00000 validation_0-logloss:0.05549 validation_1-error:0.26772 validation_1-logloss:0.66783
[52] validation_0-error:0.00000 validation_0-logloss:0.05462 validation_1-error:0.27165 validation_1-logloss:0.67103
[53] validation_0-error:0.00000 validation_0-logloss:0.05347 validation_1-error:0.26772 validation_1-logloss:0.67425
[54] validation_0-error:0.00000 validation_0-logloss:0.05253 validation_1-error:0.26378 validation_1-logloss:0.67873
[55] validation_0-error:0.00000 validation_0-logloss:0.05153 validation_1-error:0.26378 validation_1-logloss:0.67768
[56] validation_0-error:0.00000 validation_0-logloss:0.05051 validation_1-error:0.26378 validation_1-logloss:0.68269
[57] validation_0-error:0.00000 validation_0-logloss:0.04942 validation_1-error:0.26772 validation_1-logloss:0.68738
[58] validation_0-error:0.00000 validation_0-logloss:0.04892 validation_1-error:0.27165 validation_1-logloss:0.69011
[59] validation_0-error:0.00000 validation_0-logloss:0.04799 validation_1-error:0.26772 validation_1-logloss:0.69266
[60] validation_0-error:0.00000 validation_0-logloss:0.04720 validation_1-error:0.27165 validation_1-logloss:0.69469
[61] validation_0-error:0.00000 validation_0-logloss:0.04643 validation_1-error:0.27165 validation_1-logloss:0.70239
[62] validation_0-error:0.00000 validation_0-logloss:0.04535 validation_1-error:0.26772 validation_1-logloss:0.70504
[63] validation_0-error:0.00000 validation_0-logloss:0.04454 validation_1-error:0.26772 validation_1-logloss:0.70622
[64] validation_0-error:0.00000 validation_0-logloss:0.04379 validation_1-error:0.26378 validation_1-logloss:0.70810
[65] validation_0-error:0.00000 validation_0-logloss:0.04315 validation_1-error:0.25984 validation_1-logloss:0.71247
[66] validation_0-error:0.00000 validation_0-logloss:0.04241 validation_1-error:0.27165 validation_1-logloss:0.71706
[67] validation_0-error:0.00000 validation_0-logloss:0.04163 validation_1-error:0.27559 validation_1-logloss:0.71636
[68] validation_0-error:0.00000 validation_0-logloss:0.04085 validation_1-error:0.26772 validation_1-logloss:0.71625
[69] validation_0-error:0.00000 validation_0-logloss:0.04036 validation_1-error:0.26378 validation_1-logloss:0.71904
[70] validation_0-error:0.00000 validation_0-logloss:0.03993 validation_1-error:0.26378 validation_1-logloss:0.72348
[71] validation_0-error:0.00000 validation_0-logloss:0.03907 validation_1-error:0.26772 validation_1-logloss:0.72573
[72] validation_0-error:0.00000 validation_0-logloss:0.03835 validation_1-error:0.26772 validation_1-logloss:0.72761
[73] validation_0-error:0.00000 validation_0-logloss:0.03762 validation_1-error:0.26772 validation_1-logloss:0.72992
[74] validation_0-error:0.00000 validation_0-logloss:0.03719 validation_1-error:0.26772 validation_1-logloss:0.73336
[75] validation_0-error:0.00000 validation_0-logloss:0.03669 validation_1-error:0.26772 validation_1-logloss:0.73444
[76] validation_0-error:0.00000 validation_0-logloss:0.03632 validation_1-error:0.26772 validation_1-logloss:0.73795
[77] validation_0-error:0.00000 validation_0-logloss:0.03588 validation_1-error:0.27165 validation_1-logloss:0.74054
[78] validation_0-error:0.00000 validation_0-logloss:0.03521 validation_1-error:0.26772 validation_1-logloss:0.74512
[79] validation_0-error:0.00000 validation_0-logloss:0.03464 validation_1-error:0.27165 validation_1-logloss:0.74767
[80] validation_0-error:0.00000 validation_0-logloss:0.03432 validation_1-error:0.27165 validation_1-logloss:0.74878
[81] validation_0-error:0.00000 validation_0-logloss:0.03380 validation_1-error:0.28346 validation_1-logloss:0.75047
[82] validation_0-error:0.00000 validation_0-logloss:0.03343 validation_1-error:0.27559 validation_1-logloss:0.75475
[83] validation_0-error:0.00000 validation_0-logloss:0.03297 validation_1-error:0.27165 validation_1-logloss:0.75587
[84] validation_0-error:0.00000 validation_0-logloss:0.03245 validation_1-error:0.27559 validation_1-logloss:0.75861
[85] validation_0-error:0.00000 validation_0-logloss:0.03208 validation_1-error:0.26772 validation_1-logloss:0.75890
[86] validation_0-error:0.00000 validation_0-logloss:0.03169 validation_1-error:0.26772 validation_1-logloss:0.76230
[87] validation_0-error:0.00000 validation_0-logloss:0.03139 validation_1-error:0.26378 validation_1-logloss:0.76483
[88] validation_0-error:0.00000 validation_0-logloss:0.03111 validation_1-error:0.26772 validation_1-logloss:0.76738
[89] validation_0-error:0.00000 validation_0-logloss:0.03077 validation_1-error:0.26378 validation_1-logloss:0.77021
[90] validation_0-error:0.00000 validation_0-logloss:0.03041 validation_1-error:0.27165 validation_1-logloss:0.77393
[91] validation_0-error:0.00000 validation_0-logloss:0.03003 validation_1-error:0.26772 validation_1-logloss:0.77259
[92] validation_0-error:0.00000 validation_0-logloss:0.02976 validation_1-error:0.27165 validation_1-logloss:0.77214
[93] validation_0-error:0.00000 validation_0-logloss:0.02947 validation_1-error:0.26378 validation_1-logloss:0.77362
[94] validation_0-error:0.00000 validation_0-logloss:0.02912 validation_1-error:0.27165 validation_1-logloss:0.77521
[95] validation_0-error:0.00000 validation_0-logloss:0.02877 validation_1-error:0.26378 validation_1-logloss:0.77405
[96] validation_0-error:0.00000 validation_0-logloss:0.02853 validation_1-error:0.25984 validation_1-logloss:0.77413
[97] validation_0-error:0.00000 validation_0-logloss:0.02833 validation_1-error:0.26378 validation_1-logloss:0.77805
[98] validation_0-error:0.00000 validation_0-logloss:0.02809 validation_1-error:0.25984 validation_1-logloss:0.77660
[99] validation_0-error:0.00000 validation_0-logloss:0.02787 validation_1-error:0.25984 validation_1-logloss:0.77681
XGBClassifier(base_score=0.5, booster='gbtree', callbacks=None, colsample_bylevel=1, colsample_bynode=1, colsample_bytree=1, early_stopping_rounds=None, enable_categorical=False, eval_metric=['error', 'logloss'], feature_types=None, gamma=0, gpu_id=-1, grow_policy='depthwise', importance_type=None, interaction_constraints='', learning_rate=0.300000012, max_bin=256, max_cat_threshold=64, max_cat_to_onehot=4, max_delta_step=0, max_depth=6, max_leaves=0, min_child_weight=1, missing=nan, monotone_constraints='()', n_estimators=100, n_jobs=0, num_parallel_tree=1, predictor='auto', random_state=0, ...)In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
XGBClassifier(base_score=0.5, booster='gbtree', callbacks=None, colsample_bylevel=1, colsample_bynode=1, colsample_bytree=1, early_stopping_rounds=None, enable_categorical=False, eval_metric=['error', 'logloss'], feature_types=None, gamma=0, gpu_id=-1, grow_policy='depthwise', importance_type=None, interaction_constraints='', learning_rate=0.300000012, max_bin=256, max_cat_threshold=64, max_cat_to_onehot=4, max_delta_step=0, max_depth=6, max_leaves=0, min_child_weight=1, missing=nan, monotone_constraints='()', n_estimators=100, n_jobs=0, num_parallel_tree=1, predictor='auto', random_state=0, ...)
# Test-set accuracy of the model trained with learning-curve tracking.
predictions = model.predict(X_test)
accuracy = accuracy_score(y_test, predictions)
print(f'Accuracy: {accuracy*100 : 0.2f}%')
Accuracy: 74.02%
# Retrieve per-round metrics recorded during fit: validation_0 is the
# training set, validation_1 the test set (order of eval_set above).
results = model.evals_result()
epochs = len(results['validation_0']['error'])
x_axis = range(0, epochs)

# plot log loss
fig, ax = pyplot.subplots()
ax.plot(x_axis, results['validation_0']['logloss'], label='Train')
ax.plot(x_axis, results['validation_1']['logloss'], label='Test')
ax.legend()
pyplot.ylabel('Log Loss')
pyplot.title('XGBoost Log Loss')
pyplot.show()
# plot classification error per boosting round for train vs test
fig, ax = pyplot.subplots()
ax.plot(x_axis, results['validation_0']['error'], label='Train')
ax.plot(x_axis, results['validation_1']['error'], label='Test')
ax.legend()
pyplot.ylabel('Classification Error')
pyplot.title('XGBoost Classification Error')
pyplot.show()
Early Stopping With XGBoost
# Stop training once test-set log loss fails to improve for 10 rounds.
model = XGBClassifier(eval_metric='logloss')
eval_set = [(X_test, y_test)]
model.fit(X_train, y_train, early_stopping_rounds=10, eval_set=eval_set, verbose=True)
[0] validation_0-logloss:0.60491
[1] validation_0-logloss:0.55934
[2] validation_0-logloss:0.53068
[3] validation_0-logloss:0.51795
[4] validation_0-logloss:0.51153
[5] validation_0-logloss:0.50934
[6] validation_0-logloss:0.50818
[7] validation_0-logloss:0.51097
[8] validation_0-logloss:0.51760
[9] validation_0-logloss:0.51912
[10] validation_0-logloss:0.52503
[11] validation_0-logloss:0.52697
[12] validation_0-logloss:0.53335
[13] validation_0-logloss:0.53905
[14] validation_0-logloss:0.54545
[15] validation_0-logloss:0.54613
XGBClassifier(base_score=0.5, booster='gbtree', callbacks=None, colsample_bylevel=1, colsample_bynode=1, colsample_bytree=1, early_stopping_rounds=None, enable_categorical=False, eval_metric='logloss', feature_types=None, gamma=0, gpu_id=-1, grow_policy='depthwise', importance_type=None, interaction_constraints='', learning_rate=0.300000012, max_bin=256, max_cat_threshold=64, max_cat_to_onehot=4, max_delta_step=0, max_depth=6, max_leaves=0, min_child_weight=1, missing=nan, monotone_constraints='()', n_estimators=100, n_jobs=0, num_parallel_tree=1, predictor='auto', random_state=0, ...)In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
XGBClassifier(base_score=0.5, booster='gbtree', callbacks=None, colsample_bylevel=1, colsample_bynode=1, colsample_bytree=1, early_stopping_rounds=None, enable_categorical=False, eval_metric='logloss', feature_types=None, gamma=0, gpu_id=-1, grow_policy='depthwise', importance_type=None, interaction_constraints='', learning_rate=0.300000012, max_bin=256, max_cat_threshold=64, max_cat_to_onehot=4, max_delta_step=0, max_depth=6, max_leaves=0, min_child_weight=1, missing=nan, monotone_constraints='()', n_estimators=100, n_jobs=0, num_parallel_tree=1, predictor='auto', random_state=0, ...)
Tune Multithreading Support for XGBoost
XGBoost Tuning
Tune the Number and Size of Decision Trees with XGBoost
import matplotlib

# Select the non-interactive Agg backend BEFORE importing pyplot so figures
# can be saved to file on a headless machine (use() after the pyplot import
# is not guaranteed to take effect on older matplotlib versions).
matplotlib.use('Agg')

from matplotlib import pyplot
from pandas import read_csv
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.preprocessing import LabelEncoder
Tune the Number of Decision Trees
# Load the Otto Group product classification training data
# (93 integer features + a 'target' class column).
file_path = '/home/naji/Desktop/github-repos/machine-learning/nbs/0-datasets/otto/'
data = read_csv(file_path + 'train.csv')
data
id | feat_1 | feat_2 | feat_3 | feat_4 | feat_5 | feat_6 | feat_7 | feat_8 | feat_9 | feat_10 | feat_11 | feat_12 | feat_13 | feat_14 | feat_15 | feat_16 | feat_17 | feat_18 | feat_19 | feat_20 | feat_21 | feat_22 | feat_23 | feat_24 | feat_25 | feat_26 | feat_27 | feat_28 | feat_29 | feat_30 | feat_31 | feat_32 | feat_33 | feat_34 | feat_35 | feat_36 | feat_37 | feat_38 | feat_39 | ... | feat_55 | feat_56 | feat_57 | feat_58 | feat_59 | feat_60 | feat_61 | feat_62 | feat_63 | feat_64 | feat_65 | feat_66 | feat_67 | feat_68 | feat_69 | feat_70 | feat_71 | feat_72 | feat_73 | feat_74 | feat_75 | feat_76 | feat_77 | feat_78 | feat_79 | feat_80 | feat_81 | feat_82 | feat_83 | feat_84 | feat_85 | feat_86 | feat_87 | feat_88 | feat_89 | feat_90 | feat_91 | feat_92 | feat_93 | target | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 0 | 1 | 0 | 4 | 1 | 1 | 0 | 0 | 2 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 2 | 0 | 0 | 11 | 0 | 1 | 1 | 0 | 1 | 0 | 7 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 2 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | Class_1 |
1 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | ... | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 2 | 1 | 0 | 1 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | Class_1 |
2 | 3 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 6 | 0 | 0 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | Class_1 |
3 | 4 | 1 | 0 | 0 | 1 | 6 | 1 | 5 | 0 | 0 | 1 | 1 | 0 | 1 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 7 | 2 | 2 | 0 | 0 | 0 | 58 | 0 | 10 | 0 | 0 | 0 | 0 | 0 | 3 | 0 | ... | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 2 | 1 | 5 | 0 | 0 | 4 | 0 | 0 | 2 | 1 | 0 | 1 | 0 | 0 | 1 | 1 | 2 | 2 | 0 | 22 | 0 | 1 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | Class_1 |
4 | 5 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 4 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 3 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 4 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | Class_1 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
61873 | 61874 | 1 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 3 | 0 | 0 | 9 | 0 | 2 | 0 | 0 | 0 | 7 | 0 | 3 | 6 | 1 | 0 | 0 | 65 | 1 | 0 | 4 | 3 | 1 | 1 | 1 | 2 | 1 | 0 | ... | 3 | 1 | 0 | 0 | 0 | 1 | 0 | 22 | 0 | 1 | 4 | 11 | 3 | 0 | 0 | 3 | 0 | 1 | 1 | 2 | 0 | 0 | 29 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 2 | 0 | Class_9 |
61874 | 61875 | 4 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 4 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 2 | 1 | 0 | 2 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 2 | 0 | 0 | 1 | 5 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 11 | 0 | 0 | 0 | 0 | 0 | 0 | 2 | 0 | 0 | 2 | 0 | 0 | 1 | 0 | Class_9 |
61875 | 61876 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 3 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 3 | 0 | 0 | 2 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 19 | 0 | 0 | 4 | 0 | 0 | 0 | 0 | 18 | 0 | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 3 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | Class_9 |
61876 | 61877 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 3 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 2 | 1 | 0 | 0 | 1 | 2 | 0 | 0 | 2 | 1 | 0 | 0 | 5 | 0 | 0 | 0 | ... | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 2 | 0 | 0 | 0 | 0 | 6 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 3 | 10 | 0 | Class_9 |
61877 | 61878 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 3 | 0 | 1 | 1 | 1 | 1 | 0 | 0 | 0 | 3 | 0 | 2 | 1 | 0 | 0 | 0 | 9 | 0 | 0 | 0 | 0 | 0 | 0 | 2 | 0 | 0 | 0 | ... | 1 | 0 | 1 | 0 | 0 | 3 | 0 | 4 | 0 | 0 | 0 | 0 | 10 | 2 | 0 | 0 | 0 | 0 | 0 | 3 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 2 | 0 | Class_9 |
61878 rows × 95 columns
# Columns 0..93 are id + the 93 features; column 94 is the string target.
dataset = data.values
X = dataset[:, 0:94]
y = dataset[:, 94]
# Encode string class labels ('Class_1'..'Class_9') as integers for XGBoost.
label_encoded_y = LabelEncoder().fit_transform(y)

# Grid-search the number of boosting rounds with stratified 10-fold CV,
# scored by negative log loss (higher is better).
model = XGBClassifier()
n_estimators = range(50, 150, 50)
param_grid = dict(n_estimators=n_estimators)
kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=7)
grid_search = GridSearchCV(model, param_grid, scoring='neg_log_loss', cv=kfold)
grid_result = grid_search.fit(X, label_encoded_y)
KeyboardInterrupt:
# summarize results (original had an empty f-string placeholder `{}` —
# a SyntaxError — and a `pyplott` typo; both fixed here)
print(f'Best: {grid_result.best_score_} using {grid_result.best_params_}')
means = grid_result.cv_results_['mean_test_score']
stds = grid_result.cv_results_['std_test_score']
params = grid_result.cv_results_['params']
for mean, stdev, param in zip(means, stds, params):
    print(f'{mean} ({stdev}) with {param}')

# Plot mean CV score per n_estimators value with std-dev error bars;
# saved to disk because the Agg backend is non-interactive.
pyplot.errorbar(n_estimators, means, yerr=stds)
pyplot.title("XGBoost n_estimators vs Log Loss")
pyplot.xlabel('n_estimators')
pyplot.ylabel('Log Loss')
pyplot.savefig('n_estimators.png')
Tune the Size of Decision Trees
# Grid-search tree depth with the same stratified 10-fold CV setup.
model = XGBClassifier()
max_depth = range(1, 5, 2)
param_grid = dict(max_depth=max_depth)
kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=7)
grid_search = GridSearchCV(model, param_grid, scoring='neg_log_loss', cv=kfold, verbose=1)
grid_result = grid_search.fit(X, label_encoded_y)
Fitting 10 folds for each of 2 candidates, totalling 20 fits
KeyboardInterrupt:
# summarize results (dropped the duplicate `print(f'Best: {}')`, which was a
# SyntaxError: f-strings cannot contain an empty expression)
print(f'Best: {grid_result.best_score_} using {grid_result.best_params_}')
means = grid_result.cv_results_['mean_test_score']
stds = grid_result.cv_results_['std_test_score']
params = grid_result.cv_results_['params']
Tune The Number and Size of Trees