决策树

决策树之分类树

import pandas as pd
from sklearn import tree
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score
from sklearn.datasets import load_wine
import graphviz
# Classification tree on the wine dataset: train, score, cross-validate and
# export the fitted tree for visualisation.
wine = load_wine()

# Features and target side by side in one table, for inspection.
# NOTE(review): both frames use default integer column labels, so the target
# column ends up labelled 0 like the first feature -- confirm whether named
# columns are wanted.
datatarget = pd.concat([pd.DataFrame(wine.data), pd.DataFrame(wine.target)], axis=1)

# Hold out 30% of the samples for testing. The split is seeded (like the
# classifier below) so the reported scores are reproducible between runs;
# an unseeded split made every run report a different score.
xtrain, xtest, ytrain, ytest = train_test_split(
    wine.data, wine.target, test_size=0.3, random_state=10
)

clf = tree.DecisionTreeClassifier(
    # "gini" is the default. "entropy" is also available; the author notes it
    # overfits more easily, so it is meant for models that underfit.
    criterion="gini",
    random_state=10,
    splitter="best",  # "random" is the other option
    # The parameters below are all ways to limit overfitting:
    #   max_depth              maximum depth of the tree
    #   min_samples_leaf       minimum number of samples required in a leaf
    #   min_samples_split      minimum number of samples required to split a node
    #   max_features           maximum number of features considered per split
    #   min_impurity_decrease  minimum impurity decrease required for a split
).fit(xtrain, ytrain)

# Accuracy on the held-out test set. The value is bound and printed because a
# bare expression produces no output when the file is run as a script.
test_score = clf.score(xtest, ytest)
print("test accuracy:", test_score)

# Mean cross-validated score over the full dataset (default fold settings).
cv_score = cross_val_score(clf, wine.data, wine.target).mean()
print("mean cross-validation score:", cv_score)

# Export the fitted tree in DOT format and wrap it for rendering.
dot = tree.export_graphviz(
    clf,
    # feature_names=...  names of the input features
    # class_names=...    names of the target classes
    filled=True,   # colour the nodes automatically
    rounded=True,  # draw node boxes with rounded corners
)
graph = graphviz.Source(dot)

# Importance of each feature in the fitted model.
feature_importances = clf.feature_importances_
print("feature importances:", feature_importances)
# To pair each feature name with its importance:
#   [*zip(wine.feature_names, clf.feature_importances_)]

决策树之回归树:不同 max_depth 对正弦函数数据的拟合效果

import numpy as np
from sklearn.tree import DecisionTreeRegressor
import matplotlib.pyplot as plt

# Noisy sine-wave training data drawn from a fixed seed.
rng = np.random.RandomState(0)
X = 5 * rng.rand(80, 1)
X.sort(axis=0)
y = np.sin(X).ravel()
# Add noise to every fifth target value.
y[::5] += 3 * (0.5 - rng.rand(16))

# Fit two regression trees that differ only in depth.
# Apart from criterion ("mse" per the original note), regression trees take the
# same parameters as classification trees.
# NOTE(review): newer scikit-learn releases name this criterion
# "squared_error" -- confirm against the installed version.
regr_1 = DecisionTreeRegressor(max_depth=2).fit(X, y)
regr_2 = DecisionTreeRegressor(max_depth=5).fit(X, y)

# Predict on a dense grid to see how each tree behaves.
X_test = np.arange(0.0, 5.0, 0.01).reshape(-1, 1)
y_1, y_2 = (model.predict(X_test) for model in (regr_1, regr_2))

# Plot the training points together with both fitted curves.
plt.figure()
plt.scatter(X, y, s=20, c="red", edgecolor="black", label="data")
plt.plot(X_test, y_1, label="max_depth=2", color="blue", linewidth=2)
plt.plot(X_test, y_2, label="max_depth=5", color="green", linewidth=2)
plt.xlabel("data")
plt.ylabel("target")
plt.title("Decision Tree Regression")
plt.legend()
plt.show()

上一篇:Web 端使用Python生成的图片重叠、覆盖


下一篇:Sklearn