决策树对鸢尾花数据集进行分类

 #!/usr/bin/env python3
# coding=utf-8
"""Classify the iris dataset with a decision tree and export the fitted tree.

Run as a script to train the model, print predictions and accuracy, and write
the tree to a Graphviz ``.dot`` file.
"""
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_graphviz


def dectree_demo(random_state=None, out_file="iris_tree.dot"):
    """Train an entropy-criterion decision tree on iris, report results, export it.

    Args:
        random_state: Seed forwarded to both ``train_test_split`` and
            ``DecisionTreeClassifier``. The default ``None`` keeps the
            original non-deterministic behaviour; pass an int to get a
            reproducible split and tree.
        out_file: Path of the Graphviz ``.dot`` file to write
            (defaults to the original ``"iris_tree.dot"``).

    Returns:
        None. Predictions and accuracy are printed to stdout and the fitted
        tree is written to ``out_file``.
    """
    # Load the bundled iris dataset (features in .data, labels in .target).
    iris = load_iris()

    # Split into train/test sets using train_test_split's default proportions.
    x_train, x_test, y_train, y_test = train_test_split(
        iris.data, iris.target, random_state=random_state
    )

    # Fit a decision tree that chooses splits by information gain (entropy).
    estimator = DecisionTreeClassifier(criterion="entropy", random_state=random_state)
    estimator.fit(x_train, y_train)

    # Method 1: print predictions and an element-wise "prediction == truth" mask.
    # (Labels below: "compare true vs predicted values".)
    y_predict = estimator.predict(x_test)
    print("y_predict:\n", y_predict)
    print("对比真实值和预测值:\n", y_test == y_predict)

    # Method 2: mean accuracy on the held-out test set. (Label: "accuracy".)
    score = estimator.score(x_test, y_test)
    print("准确率:\n", score)

    # Export the tree for Graphviz; class_names labels leaves with species
    # names instead of bare class indices, filled=True colours nodes by class.
    export_graphviz(
        estimator,
        out_file=out_file,
        feature_names=iris.feature_names,
        class_names=iris.target_names,
        filled=True,
    )


if __name__ == "__main__":
    # Guarded so importing this module does not train a model or write files.
    dectree_demo()

 

上一篇:实验一


下一篇:SAS使用鸢尾花(iris)数据集训练人工神经网络(ANN)模型