#!/usr/bin/env python3
# coding=utf-8
"""Classify the iris dataset with a decision tree and export the tree for Graphviz."""
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_graphviz


def dectree_demo(random_state=None, out_file="iris_tree.dot"):
    """Train an entropy-criterion decision tree on iris, report results, export the tree.

    Splits the iris data into train/test sets using ``train_test_split``'s
    default proportions and fits ``DecisionTreeClassifier(criterion="entropy")``
    on the training split. It then prints the test-set predictions, an
    element-wise comparison with the true labels, and the score returned by
    ``estimator.score``. Finally it writes the fitted tree to ``out_file`` in
    Graphviz DOT format.

    Args:
        random_state: Seed (or ``None``) forwarded to both ``train_test_split``
            and the classifier. ``None`` (the default) keeps the original
            unseeded behaviour; pass an int for reproducible runs.
        out_file: Path of the DOT file to write. Defaults to ``"iris_tree.dot"``.

    Returns:
        None. Results go to stdout and the tree is written to ``out_file``.
    """
    # Load the data.
    iris = load_iris()

    # Split into training and test sets.
    x_train, x_test, y_train, y_test = train_test_split(
        iris.data, iris.target, random_state=random_state
    )

    # Decision-tree estimator that uses entropy as its split criterion.
    estimator = DecisionTreeClassifier(criterion="entropy", random_state=random_state)
    estimator.fit(x_train, y_train)

    # Model evaluation, method 1: compare predictions with the true labels directly.
    # (Label below means "compare true values and predicted values".)
    y_predict = estimator.predict(x_test)
    print("y_predict:\n", y_predict)
    print("对比真实值和预测值:\n", y_test == y_predict)

    # Method 2: compute the accuracy. (Label below means "accuracy".)
    score = estimator.score(x_test, y_test)
    print("准确率:\n", score)

    # Visualize the decision tree by writing it as a Graphviz DOT file.
    # class_names makes each node show a class name rather than omitting it.
    export_graphviz(
        estimator,
        out_file=out_file,
        feature_names=iris.feature_names,
        class_names=iris.target_names,
    )
    return None


if __name__ == "__main__":
    dectree_demo()