ML / xgboost: Training on the mushroom dataset (22+1, 6513+1611) with the xgboost algorithm (sklearn + GridSearchCV) to predict whether a mushroom is poisonous (binary classification)

Output results

 

Design approach


Core code

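from sklearn.model_selection import GridSearchCV  # sklearn.grid_search was removed in scikit-learn 0.20

# Search n_estimators over 1..50 with 5-fold CV; bst is the xgboost estimator defined earlier (not shown here)
param_test = {'n_estimators': range(1, 51, 1)}
clf = GridSearchCV(estimator=bst, param_grid=param_test, cv=5)
clf.fit(X_train, y_train)

# grid_scores_ no longer exists; cv_results_ holds the per-candidate mean and std scores
clf.cv_results_, clf.best_params_, clf.best_score_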
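
In current scikit-learn, the per-candidate statistics are stored in `clf.cv_results_`, a dict of arrays. The following is a minimal sketch of how the mean and std lists could be pulled out of it. The helper name `summarize_cv_results` is hypothetical, and the small dict in the usage lines is a made-up stand-in for a fitted `clf.cv_results_`, not real output.

```python
def summarize_cv_results(cv_results, param_name="n_estimators"):
    """Return (params, means, stds) lists from a GridSearchCV-style cv_results_ dict.

    Assumes the standard keys 'param_<name>', 'mean_test_score' and
    'std_test_score'. Scores are rounded to 5 decimals for readability.
    """
    params = [int(p) for p in cv_results["param_" + param_name]]
    means = [round(float(m), 5) for m in cv_results["mean_test_score"]]
    stds = [round(float(s), 5) for s in cv_results["std_test_score"]]
    return params, means, stds


# Illustrative stand-in for clf.cv_results_ (values are made up for the example)
example_results = {
    "param_n_estimators": [1, 2, 3],
    "mean_test_score": [0.90, 0.95, 0.97],
    "std_test_score": [0.09, 0.07, 0.02],
}
params, means, stds = summarize_cv_results(example_results)
print(params, means, stds)
```

With a fitted search, `summarize_cv_results(clf.cv_results_)` would take the place of copying the scores by hand.
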

grid_scores_mean = [0.90542,  0.94749,  0.90542,  0.94749,  0.90573,  0.94718,
                    0.90542,  0.94242,  0.94473,  0.97482,  0.94887,  0.97850,
                    0.97298,  0.97850,  0.97298,  0.97850,  0.97850,  0.97850,
                    0.97850,  0.97850,  0.97850,  0.97850,  0.97850,  0.97850,
                    0.97850,  0.97804,  0.97774,  0.97835,  0.98296,  0.98419,
                    0.98342,  0.98372,  0.98419,  0.98419,  0.98419,  0.98419,
                    0.98419,  0.98419,  0.98419,  0.98419,  0.98419,  0.98419,
                    0.98419,  0.98419,  0.98419,  0.98419,  0.98419,  0.98419,
                    0.98419]

grid_scores_std = [0.08996,  0.07458,  0.08996,  0.07458,  0.09028,  0.07436,
                   0.08996,  0.07331,  0.07739,  0.02235,  0.07621,  0.02387,
                   0.03186,  0.02387,  0.03186,  0.02387,  0.02387,  0.02387,
                   0.02387,  0.02387,  0.02387,  0.02387,  0.02387,  0.02387,
                   0.02387,  0.02365,  0.02337,  0.02383,  0.01963,  0.02040,
                   0.01988,  0.02008,  0.02040,  0.02040,  0.02040,  0.02040,
                   0.02040,  0.02040,  0.02040,  0.02040,  0.02040,  0.02040,
                   0.02040,  0.02040,  0.02040,  0.02040,  0.02040,  0.02040,
                   0.02040]

# 7 - Visualize the cross-validation curve
import matplotlib.pyplot as plt

x = range(1, len(grid_scores_mean) + 1)  # n_estimators candidates start at 1, not 0
y1 = grid_scores_mean
y2 = grid_scores_std
Xlabel = 'n_estimators'
Ylabel = 'value'
title = 'mushroom dataset: xgboost(sklearn+GridSearchCV) model'

plt.rcParams['font.sans-serif'] = ['Times New Roman']  # set the font; use ['FangSong'] or ['SimHei'] for Chinese labels
# myfont = matplotlib.font_manager.FontProperties(fname='C:/Windows/Fonts/msyh.ttf')  # alternatively, point at a Windows system font file
plt.rcParams['axes.unicode_minus'] = False  # render minus signs on the axes correctly

plt.plot(x, y1, 'r', label='Mean')  # mean CV score curve
plt.plot(x, y2, 'g', label='Std')   # std of the CV scores
plt.xlabel(Xlabel)
plt.ylabel(Ylabel)
plt.title(title)
plt.legend(loc=1)  # loc=1 places the legend in the upper right
plt.show()
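
The mean curve flattens once it reaches its maximum, so a natural follow-up is to pick the smallest `n_estimators` that already attains the best mean score. The sketch below shows one way to do that. The function name `smallest_best_candidate` and its `tol` parameter are hypothetical, and the scores in the usage lines are made-up illustration values rather than results from the mushroom run.

```python
def smallest_best_candidate(candidates, means, tol=0.0):
    """Return the first candidate whose mean score is within tol of the best mean.

    candidates and means are parallel sequences, ordered from the cheapest
    model to the most expensive one.
    """
    if not means:
        raise ValueError("means is empty")
    if len(candidates) != len(means):
        raise ValueError("candidates and means must have the same length")
    best = max(means)
    for candidate, mean in zip(candidates, means):
        if mean >= best - tol:
            return candidate


# Illustrative values only: three candidate settings and their mean CV scores
candidates = [10, 20, 30]
means = [0.90, 0.97, 0.98]
print(smallest_best_candidate(candidates, means))            # strict best
print(smallest_best_candidate(candidates, means, tol=0.02))  # accept a small loss
```

A non-zero `tol` trades a small drop in mean score for a smaller model, which can be reasonable when the gap is well inside the reported standard deviation.
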

