西瓜决策树-ID3算法
ID3决策树算法
背景知识
ID3算法最早是由罗斯昆兰(J. Ross Quinlan)于1975年在悉尼大学提出的一种分类预测算法,算法的核心是“信息熵”。ID3算法通过计算每个属性的信息增益,认为信息增益高的是好属性,每次划分选取信息增益最高的属性为划分标准,重复这个过程,直至生成一个能完美分类训练样例的决策树。
决策树是对数据进行分类,以此达到预测的目的。该决策树方法先根据训练集数据形成决策树,如果该树不能对所有对象给出正确的分类,那么选择一些例外加入到训练集数据中,重复该过程一直到形成正确的决策集。决策树代表着决策集的树形结构。
决策树由决策结点、分支和叶子组成。决策树中最上面的结点为根结点,每个分支是一个新的决策结点,或者是树的叶子。每个决策结点代表一个问题或决策,通常对应于待分类对象的属性。每一个叶子结点代表一种可能的分类结果。沿决策树从上到下遍历的过程中,在每个结点都会遇到一个测试,对每个结点上问题的不同的测试输出导致不同的分支,最后会到达一个叶子结点,这个过程就是利用决策树进行分类的过程,利用若干个变量来判断所属的类别。
数据描述
所使用的样本数据有一定的要求,ID3是:
- 描述-属性-值相同的属性必须描述每个例子和有固定数量的价值观。
- 预定义类-实例的属性必须已经定义的,也就是说,他们不是学习的ID3。
- 离散类-类必须是尖锐的鲜明。连续类分解成模糊范畴(如金属被“努力,很困难的,灵活的,温柔的,很软”都是不可信的。
- 足够的例子——因为归纳概括用于(即不可查明)必须选择足够多的测试用例来区分有效模式并消除特殊巧合因素的影响。
ID3决定哪些属性如何是最好的。一个统计特性,被称为信息增益,使用熵得到给定属性衡量培训例子带入目标类分开。信息增益最高的信息(信息是最有益的分类)被选择。为了明确增益,我们首先从信息论借用一个定义,叫做熵。每个属性都有一个熵。
概述
ID3决策树是一种非常重要的用来处理分类问题的结构,它形似一个嵌套N层的IF…ELSE结构,但是它的判断标准不再是一个关系表达式,而是对应的模块的信息增益。它通过信息增益的大小,从根节点开始,选择一个分支,如同进入一个IF结构的statement,通过属性值的取值不同进入新的IF结构的statement,直到到达叶子节点,找到它所属的“分类”标签。
它的流程图是一颗无法保证平衡的多叉树,每一个父节点都是一个判断模块,通过判断,当前的向量会进入它的某一个子节点中,这个子节点是判断模块或者终止模块(叶子节点),当且仅当这个向量到达叶子节点,它也就找到了它的“分类”标签。
ID3决策树通过一个固定的训练集是可以形成一颗永久的“树”的,这课树可以进行保存并且运用到不同的测试集中,唯一的要求就是测试集和训练集需要是结构等价的。这个训练过程就是根据训练集创建规则的过程,这也是机器学习的过程。
实现
数据转换
西瓜数据集.txt:
青绿 蜷缩 浊响 清晰 凹陷 硬滑 是
乌黑 蜷缩 沉闷 清晰 凹陷 硬滑 是
乌黑 蜷缩 浊响 清晰 凹陷 硬滑 是
青绿 蜷缩 沉闷 清晰 凹陷 硬滑 是
浅白 蜷缩 浊响 清晰 凹陷 硬滑 是
青绿 稍蜷 浊响 清晰 稍凹 软粘 是
乌黑 稍蜷 浊响 稍糊 稍凹 软粘 是
乌黑 稍蜷 浊响 清晰 稍凹 硬滑 是
乌黑 稍蜷 沉闷 稍糊 稍凹 硬滑 否
青绿 硬挺 清脆 清晰 平坦 软粘 否
浅白 硬挺 清脆 模糊 平坦 硬滑 否
浅白 蜷缩 浊响 模糊 平坦 软粘 否
青绿 稍蜷 浊响 稍糊 凹陷 硬滑 否
浅白 稍蜷 沉闷 稍糊 凹陷 硬滑 否
乌黑 稍蜷 浊响 清晰 稍凹 软粘 否
浅白 蜷缩 浊响 模糊 平坦 硬滑 否
青绿 蜷缩 沉闷 稍糊 稍凹 硬滑 否
将txt文档转为csv方法:txt转换成csv
引入包
import numpy as np
import pandas as pd
import sklearn.tree as st
import math
import matplotlib
import os
import matplotlib.pyplot as plt
data = pd.read_csv('F:/西瓜数据集.csv',header=None)
data
熵
def calcEntropy(dataSet):
mD = len(dataSet)
dataLabelList = [x[-1] for x in dataSet]
dataLabelSet = set(dataLabelList)
ent = 0
for label in dataLabelSet:
mDv = dataLabelList.count(label)
prop = float(mDv) / mD
ent = ent - prop * np.math.log(prop, 2)
return ent
拆分数据集
# index - 要拆分的特征的下标
# feature - 要拆分的特征
# 返回值 - dataSet中index所在特征为feature,且去掉index一列的集合
def splitDataSet(dataSet, index, feature):
splitedDataSet = []
mD = len(dataSet)
for data in dataSet:
if(data[index] == feature):
sliceTmp = data[:index]
sliceTmp.extend(data[index + 1:])
splitedDataSet.append(sliceTmp)
return splitedDataSet
选择最优特征
# 返回值 - 最好的特征的下标
def chooseBestFeature(dataSet):
entD = calcEntropy(dataSet)
mD = len(dataSet)
featureNumber = len(dataSet[0]) - 1
maxGain = -100
maxIndex = -1
for i in range(featureNumber):
entDCopy = entD
featureI = [x[i] for x in dataSet]
featureSet = set(featureI)
for feature in featureSet:
splitedDataSet = splitDataSet(dataSet, i, feature) # 拆分数据集
mDv = len(splitedDataSet)
entDCopy = entDCopy - float(mDv) / mD * calcEntropy(splitedDataSet)
if(maxIndex == -1):
maxGain = entDCopy
maxIndex = i
elif(maxGain < entDCopy):
maxGain = entDCopy
maxIndex = i
return maxIndex
寻找最多作为标签
# 返回值 - 标签
def mainLabel(labelList):
labelRec = labelList[0]
maxLabelCount = -1
labelSet = set(labelList)
for label in labelSet:
if(labelList.count(label) > maxLabelCount):
maxLabelCount = labelList.count(label)
labelRec = label
return labelRec
生成树
def createFullDecisionTree(dataSet, featureNames, featureNamesSet, labelListParent):
labelList = [x[-1] for x in dataSet]
if(len(dataSet) == 0):
return mainLabel(labelListParent)
elif(len(dataSet[0]) == 1): #没有可划分的属性了
return mainLabel(labelList) #选出最多的label作为该数据集的标签
elif(labelList.count(labelList[0]) == len(labelList)): # 全部都属于同一个Label
return labelList[0]
bestFeatureIndex = chooseBestFeature(dataSet)
bestFeatureName = featureNames.pop(bestFeatureIndex)
myTree = {bestFeatureName: {}}
featureList = featureNamesSet.pop(bestFeatureIndex)
featureSet = set(featureList)
for feature in featureSet:
featureNamesNext = featureNames[:]
featureNamesSetNext = featureNamesSet[:][:]
splitedDataSet = splitDataSet(dataSet, bestFeatureIndex, feature)
myTree[bestFeatureName][feature] = createFullDecisionTree(splitedDataSet, featureNamesNext, featureNamesSetNext, labelList)
return myTree
初始化
# 返回值
# dataSet 数据集
# featureNames 标签
# featureNamesSet 列标签
def readWatermelonDataSet():
dataSet = data.values.tolist()
featureNames =['色泽', '根蒂', '敲击', '纹理', '脐部', '触感']
#获取featureNamesSet
featureNamesSet = []
for i in range(len(dataSet[0]) - 1):
col = [x[i] for x in dataSet]
colSet = set(col)
featureNamesSet.append(list(colSet))
return dataSet, featureNames, featureNamesSet
画图
# 能够显示中文
matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['font.serif'] = ['SimHei']
# 分叉节点,也就是决策节点
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
# 叶子节点
leafNode = dict(boxstyle="round4", fc="0.8")
# 箭头样式
arrow_args = dict(arrowstyle="<-")
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
"""
绘制一个节点
:param nodeTxt: 描述该节点的文本信息
:param centerPt: 文本的坐标
:param parentPt: 点的坐标,这里也是指父节点的坐标
:param nodeType: 节点类型,分为叶子节点和决策节点
:return:
"""
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
def getNumLeafs(myTree):
"""
获取叶节点的数目
:param myTree:
:return:
"""
# 统计叶子节点的总数
numLeafs = 0
# 得到当前第一个key,也就是根节点
firstStr = list(myTree.keys())[0]
# 得到第一个key对应的内容
secondDict = myTree[firstStr]
# 递归遍历叶子节点
for key in secondDict.keys():
# 如果key对应的是一个字典,就递归调用
if type(secondDict[key]).__name__ == 'dict':
numLeafs += getNumLeafs(secondDict[key])
# 不是的话,说明此时是一个叶子节点
else:
numLeafs += 1
return numLeafs
def getTreeDepth(myTree):
"""
得到数的深度层数
:param myTree:
:return:
"""
# 用来保存最大层数
maxDepth = 0
# 得到根节点
firstStr = list(myTree.keys())[0]
# 得到key对应的内容
secondDic = myTree[firstStr]
# 遍历所有子节点
for key in secondDic.keys():
# 如果该节点是字典,就递归调用
if type(secondDic[key]).__name__ == 'dict':
# 子节点的深度加1
thisDepth = 1 + getTreeDepth(secondDic[key])
# 说明此时是叶子节点
else:
thisDepth = 1
# 替换最大层数
if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth
def plotMidText(cntrPt, parentPt, txtString):
"""
计算出父节点和子节点的中间位置,填充信息
:param cntrPt: 子节点坐标
:param parentPt: 父节点坐标
:param txtString: 填充的文本信息
:return:
"""
# 计算x轴的中间位置
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
# 计算y轴的中间位置
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
# 进行绘制
createPlot.ax1.text(xMid, yMid, txtString)
def plotTree(myTree, parentPt, nodeTxt):
"""
绘制出树的所有节点,递归绘制
:param myTree: 树
:param parentPt: 父节点的坐标
:param nodeTxt: 节点的文本信息
:return:
"""
# 计算叶子节点数
numLeafs = getNumLeafs(myTree=myTree)
# 计算树的深度
depth = getTreeDepth(myTree=myTree)
# 得到根节点的信息内容
firstStr = list(myTree.keys())[0]
# 计算出当前根节点在所有子节点的中间坐标,也就是当前x轴的偏移量加上计算出来的根节点的中心位置作为x轴(比如说第一次:初始的x偏移量为:-1/2W,计算出来的根节点中心位置为:(1+W)/2W,相加得到:1/2),当前y轴偏移量作为y轴
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
# 绘制该节点与父节点的联系
plotMidText(cntrPt, parentPt, nodeTxt)
# 绘制该节点
plotNode(firstStr, cntrPt, parentPt, decisionNode)
# 得到当前根节点对应的子树
secondDict = myTree[firstStr]
# 计算出新的y轴偏移量,向下移动1/D,也就是下一层的绘制y轴
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
# 循环遍历所有的key
for key in secondDict.keys():
# 如果当前的key是字典的话,代表还有子树,则递归遍历
if isinstance(secondDict[key], dict):
plotTree(secondDict[key], cntrPt, str(key))
else:
# 计算新的x轴偏移量,也就是下个叶子绘制的x轴坐标向右移动了1/W
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
# 打开注释可以观察叶子节点的坐标变化
# print((plotTree.xOff, plotTree.yOff), secondDict[key])
# 绘制叶子节点
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
# 绘制叶子节点和父节点的中间连线内容
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
# 返回递归之前,需要将y轴的偏移量增加,向上移动1/D,也就是返回去绘制上一层的y轴
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
def createPlot(inTree):
"""
需要绘制的决策树
:param inTree: 决策树字典
:return:
"""
# 创建一个图像
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
# 计算出决策树的总宽度
plotTree.totalW = float(getNumLeafs(inTree))
# 计算出决策树的总深度
plotTree.totalD = float(getTreeDepth(inTree))
# 初始的x轴偏移量,也就是-1/2W,每次向右移动1/W,也就是第一个叶子节点绘制的x坐标为:1/2W,第二个:3/2W,第三个:5/2W,最后一个:(W-1)/2W
plotTree.xOff = -0.5/plotTree.totalW
# 初始的y轴偏移量,每次向下或者向上移动1/D
plotTree.yOff = 1.0
# 调用函数进行绘制节点图像
plotTree(inTree, (0.5, 1.0), '')
# 绘制
plt.show()
结果
dataSet, featureNames, featureNamesSet=readWatermelonDataSet()
testTree= createFullDecisionTree(dataSet, featureNames, featureNamesSet,featureNames)
createPlot(testTree)