机器学习入门(七):分类算法——决策树算法

学习目录:
机器学习入门(七):分类算法——决策树算法
决策树内容目录:
机器学习入门(七):分类算法——决策树算法

一.决策树作用:

机器学习入门(七):分类算法——决策树算法
   这是我们判断这是个好瓜还是坏瓜的决策流程,决策树的作用:
1.帮助我们选择用哪个特征先做if,用哪个特征后做if,能最快的判断出这是好瓜还是坏瓜
2.帮助我们确定特征中作为划分标准的数值

二.原理推导

三.代码预测:

机器学习入门(七):分类算法——决策树算法

案例对比:比较决策树算法和KNN算法在鸢尾花数据集上的分类准确率

使用决策树算法对鸢尾花数据集分类:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
def dicision_iris():
   """使用决策树对鸢尾花数据集进行分类
   :return:"""
   # 获取数据
   iris = load_iris()
   # 划分数据集
   x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=22)
   #决策树估计器
   estimator=DecisionTreeClassifier(criterion="entropy")
   estimator.fit(x_train, y_train)  # 把训练数据放进去
   #模型评估
   #方法一:直接比对真实值和预测值
   y_predict=estimator.predict(x_test)
   print('y_predict:\n',y_predict)
   print('直接比对真实值和预测值:\n', y_test==y_predict)
   # 方法二:计算准确率
   score = estimator.score(x_test,y_test)
   print('准确率:\n', score)
   
if __name__=='__main__':
    dicision_iris()

机器学习入门(七):分类算法——决策树算法
使用KNN算法对鸢尾花数据集分类:
机器学习入门(七):分类算法——决策树算法
结论:KNN算法的准确率更高,因为KNN算法本来就适合小数据集(鸢尾花数据集只有150个样本),他去一个一个计算距离。而决策树更适合大数据集。--------不同算法适合不同场景

四.决策树可视化

机器学习入门(七):分类算法——决策树算法
机器学习入门(七):分类算法——决策树算法
生成的dot文件:

digraph Tree {
node [shape=box] ;
0 [label="petal width (cm) <= 0.75\nentropy = 1.584\nsamples = 112\nvalue = [39, 37, 36]"] ;
1 [label="entropy = 0.0\nsamples = 39\nvalue = [39, 0, 0]"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="petal width (cm) <= 1.75\nentropy = 1.0\nsamples = 73\nvalue = [0, 37, 36]"] ;
0 -> 2 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
3 [label="petal length (cm) <= 5.05\nentropy = 0.391\nsamples = 39\nvalue = [0, 36, 3]"] ;
2 -> 3 ;
4 [label="sepal length (cm) <= 4.95\nentropy = 0.183\nsamples = 36\nvalue = [0, 35, 1]"] ;
3 -> 4 ;
5 [label="petal length (cm) <= 3.9\nentropy = 1.0\nsamples = 2\nvalue = [0, 1, 1]"] ;
4 -> 5 ;
6 [label="entropy = 0.0\nsamples = 1\nvalue = [0, 1, 0]"] ;
5 -> 6 ;
7 [label="entropy = 0.0\nsamples = 1\nvalue = [0, 0, 1]"] ;
5 -> 7 ;
8 [label="entropy = 0.0\nsamples = 34\nvalue = [0, 34, 0]"] ;
4 -> 8 ;
9 [label="petal width (cm) <= 1.55\nentropy = 0.918\nsamples = 3\nvalue = [0, 1, 2]"] ;
3 -> 9 ;
10 [label="entropy = 0.0\nsamples = 2\nvalue = [0, 0, 2]"] ;
9 -> 10 ;
11 [label="entropy = 0.0\nsamples = 1\nvalue = [0, 1, 0]"] ;
9 -> 11 ;
12 [label="petal length (cm) <= 4.85\nentropy = 0.191\nsamples = 34\nvalue = [0, 1, 33]"] ;
2 -> 12 ;
13 [label="entropy = 0.0\nsamples = 1\nvalue = [0, 1, 0]"] ;
12 -> 13 ;
14 [label="entropy = 0.0\nsamples = 33\nvalue = [0, 0, 33]"] ;
12 -> 14 ;
}

机器学习入门(七):分类算法——决策树算法
把dot文件内容复制放进这个网站:
机器学习入门(七):分类算法——决策树算法

五.决策树算法总结:

优点:
          可视化—可解释能力强
缺点:
          当数据量非常大时,决策树会过于复杂,容易导致对训练样本好,测试样本差(过拟合)
如何改进?
          减枝cart算法(决策树API中已实现)
          随机森林算法(下一节讲)

六.案例:泰坦尼克号乘客生存预测

机器学习入门(七):分类算法——决策树算法
机器学习入门(七):分类算法——决策树算法
提取出来仓位、年龄、性别,其他的不影响他是否生存:
机器学习入门(七):分类算法——决策树算法
对年龄里的缺失值进行处理,填入平均值:
机器学习入门(七):分类算法——决策树算法
转换成字典格式:
机器学习入门(七):分类算法——决策树算法划分数据集、字典特征抽取、决策树预估器、评估:
机器学习入门(七):分类算法——决策树算法
结果:
机器学习入门(七):分类算法——决策树算法

import pandas as pd
import graphviz
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction import DictVectorizer

# 获取数据
path = "http://biostat.mc.vanderbilt.edu/wiki/pub/Main/DataSets/titanic.txt"
titanic = pd.read_csv(path)

# 获取特征值与目标值
x = titanic[['pclass', 'age', 'sex']]
y = titanic['survived']

# 处理数据
x['age'].fillna(x['age'].mean(), inplace=True)
# 转换成字典
x = x.to_dict(orient='recorda')
# 划分数据
x_train, x_test, y_train, y_test = train_test_split(x, y)

# 字典特征提取
transfer = DictVectorizer()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)

# 构建模型
estimator = DecisionTreeClassifier(criterion='entropy')
estimator.fit(x_train, y_train)
y_predict = estimator.predict(x_test)
print('预测值与真实值比对:\n', y_predict == y_test)
# 求准确率
score = model.score(x_test, y_test)
print('准确率:\n', score)
# 可视化决策树
image = export_graphviz(
      estimator,
      out_file="C:/Users/Admin/Desktop/iris_tree.dot",
      feature_names=transfer.get_feature_names(),
)

上一篇:OS + Linux Kali / Debian BackTrack


下一篇:机器学习-决策树之ID3算法