# _*_ coding:utf-8 _*_
import numpy as np
import math
import operator
import sys
import pickle
def createDataSet():
    """Build the small weather ("play or not") toy data set.

    Each row holds four encoded feature values followed by the class label
    ('N' or 'Y'). Feature encodings:

        outlook     -> 0: sunny | 1: overcast | 2: rain
        temperature -> 0: hot   | 1: mild     | 2: cool
        humidity    -> 0: high  | 1: normal
        windy       -> 0: false | 1: true

    Returns:
        (rows, feature_names): a fresh list of rows and the list of feature
        names in column order.
    """
    feature_names = ['outlook', 'temperature', 'humidity', 'windy']
    rows = [
        [0, 0, 0, 0, 'N'],
        [0, 0, 0, 1, 'N'],
        [1, 0, 0, 0, 'Y'],
        [2, 1, 0, 0, 'Y'],
        [2, 2, 1, 0, 'Y'],
        [2, 2, 1, 1, 'N'],
        [1, 2, 1, 1, 'Y'],
    ]
    return rows, feature_names
def creatDataSet1():
    """Build the 15-row loan-approval toy data set.

    Each row holds four encoded feature values followed by the class label
    ('yes' or 'no'). The feature names are, in column order: age, has a job,
    owns a house, credit rating.

    Returns:
        (rows, feature_names): a fresh list of rows and the feature names.
    """
    # Feature names (column order matches the rows below).
    feature_names = ['年龄', '有工作', '有自己的房子', '信贷情况']
    # Data set: [age, has_job, owns_house, credit, label]
    rows = [
        [0, 0, 0, 0, 'no'],
        [0, 0, 0, 1, 'no'],
        [0, 1, 0, 1, 'yes'],
        [0, 1, 1, 0, 'yes'],
        [0, 0, 0, 0, 'no'],
        [1, 0, 0, 0, 'no'],
        [1, 0, 0, 1, 'no'],
        [1, 1, 1, 1, 'yes'],
        [1, 0, 1, 2, 'yes'],
        [1, 0, 1, 2, 'yes'],
        [2, 0, 1, 2, 'yes'],
        [2, 0, 1, 1, 'yes'],
        [2, 1, 0, 1, 'yes'],
        [2, 1, 0, 2, 'yes'],
        [2, 0, 0, 0, 'no'],
    ]
    # Return the data set and the feature names.
    return rows, feature_names
def cal_entropy(dataset):
    """Return the base-2 Shannon entropy of the class labels in *dataset*.

    The class label of each row is its last element. An empty dataset has
    entropy 0.0.
    """
    total = len(dataset)
    counts = {}
    for row in dataset:
        label = row[-1]
        counts[label] = counts.get(label, 0) + 1
    entropy = 0.0
    for count in counts.values():
        p = count / total
        entropy -= p * math.log(p, 2)
    return entropy
def selectBestFeature(dataset):
    """Return the index of the feature column with the highest information gain.

    Each row of *dataset* is a list whose last element is the class label;
    every other column is a discrete feature. Ties are broken in favour of
    the lowest column index.

    Returns:
        The best feature index, or -1 when *dataset* is empty or its rows
        contain no feature columns (only the label).
    """
    if not dataset:
        return -1
    base_entropy = cal_entropy(dataset)
    total = len(dataset)
    best_index = -1
    # Start below any achievable gain so that some feature is always chosen,
    # even when every gain is 0 (or a hair below 0 from float rounding).
    # Starting at 0.0 returned -1 in that case, and callers then treated
    # column -1 (the class label) as a feature.
    best_gain = float('-inf')
    for i in range(len(dataset[0]) - 1):
        cond_entropy = 0.0
        for value in set(row[i] for row in dataset):
            subset = splitdataset(dataset, i, value)
            cond_entropy += len(subset) / total * cal_entropy(subset)
        gain = base_entropy - cond_entropy
        if gain > best_gain:
            best_gain = gain
            best_index = i
    return best_index
def splitdataset(dataset, i, value):
    """Return the rows whose column *i* equals *value*, with column *i* removed.

    The input rows are not modified; each returned row is a new list.
    """
    def _without_column(row):
        # Slicing copies, so extending never touches the caller's row.
        reduced = row[:i]
        reduced.extend(row[i + 1:])
        return reduced

    return [_without_column(row) for row in dataset if row[i] == value]
def createTree(dataset, labels):
    """Recursively build an ID3 decision tree.

    Args:
        dataset: non-empty list of rows; the last element of each row is the
            class label and the other columns are discrete feature values.
        labels: feature names, one per feature column of *dataset*. The
            caller's list is not modified.

    Returns:
        Either a class label (a leaf) or a nested dict of the form
        {feature_name: {feature_value: subtree, ...}}.
    """
    from collections import Counter

    classes = [row[-1] for row in dataset]
    # Pure node: every row carries the same class.
    if classes.count(classes[0]) == len(classes):
        return classes[0]
    # No feature columns left: fall back to a majority vote
    # (ties go to the class encountered first).
    if len(dataset[0]) == 1:
        return Counter(classes).most_common(1)[0][0]
    best_index = selectBestFeature(dataset)
    # Defensive: no usable feature was reported, so vote instead of
    # indexing labels/columns with -1 (which would hit the class column).
    if best_index < 0:
        return Counter(classes).most_common(1)[0][0]
    best_label = labels[best_index]
    # Build a new label list for the children. Deleting from the shared list
    # (as before) corrupted the labels seen by every sibling branch after
    # the first one that recursed.
    sub_labels = labels[:best_index] + labels[best_index + 1:]
    tree = {best_label: {}}
    for value in set(row[best_index] for row in dataset):
        subset = splitdataset(dataset, best_index, value)
        tree[best_label][value] = createTree(subset, sub_labels)
    return tree
def classify(inputTree, featLabel, testVec):
    """Predict the class of *testVec* by walking a tree built by createTree.

    Args:
        inputTree: a class label (leaf) or a nested dict
            {feature_name: {feature_value: subtree, ...}}.
        featLabel: feature names in the same column order as *testVec*.
        testVec: the feature values of one sample.

    Returns:
        The class label stored at the leaf reached.

    Raises:
        KeyError: if *testVec* holds a value for which the tree has no branch.
        ValueError: if the tree tests a feature that is not in *featLabel*.
    """
    node = inputTree
    while isinstance(node, dict):
        feature = next(iter(node))
        branches = node[feature]
        value = testVec[featLabel.index(feature)]
        try:
            node = branches[value]
        except KeyError:
            raise KeyError('no branch for %r = %r' % (feature, value)) from None
    return node
# def storeTree(input,filename):
#     with open(filename,'wb') as fw:
#         pickle.dump(input,fw)
if __name__ == '__main__':
    dataset, labels = creatDataSet1()
    # Pass a copy so the full feature-name list is still intact for classify.
    mytree = createTree(dataset, labels[:])
    print(mytree)
    # One value per feature column: age, has job, owns house, credit rating.
    testVec = [0, 1, 0, 0]
    result = classify(mytree, labels, testVec)
    print(result)