Introduction:
分类与回归树(classification and regression tree, CART)模型由Breiman等人在1984年提出,CART同样由特征选择、树的生成及剪枝组成,既可以用于分类也可以用于回归,以下简要讨论树生成部分,在随后的博文中再探讨树剪枝的问题。
Algorithm:
step 1. 对每个特征 $A_i$ 的每个可能取值 $j$,分别计算按 $A_i=j$ 将数据集 $D$ 二分后的基尼系数 $Gini(D,A_i=j)$
step 2. 选择基尼系数最小的特征及其取值作为最优切分点:若 $Gini(D,A_i=j)$ 最小,则 $A_i=j$ 为最优切分点,$A_i$ 作为根节点
step 3. 在剩余的特征中重复step 1和2,获取最优特征及最优切分点,直至所有特征用尽或者所有样本都已归类,最后所生成的决策树与ID3算法所生成的完全一致
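Formula:
概率分布的基尼系数($K$ 个类别,样本属于第 $k$ 类的概率为 $p_k$):$Gini(p)=\sum_{k=1}^{K}p_k(1-p_k)=1-\sum_{k=1}^{K}p_k^2$
样本集合 $D$ 的基尼系数($C_k$ 为 $D$ 中属于第 $k$ 类的样本子集):$Gini(D)=1-\sum_{k=1}^{K}\left(\frac{|C_k|}{|D|}\right)^2$
按特征 $A$ 是否取值 $a$ 将 $D$ 分为 $D_1$ 和 $D_2$ 后的基尼系数:$Gini(D,A=a)=\frac{|D_1|}{|D|}Gini(D_1)+\frac{|D_2|}{|D|}Gini(D_2)$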
Code:
# -*- coding: utf-8 -*-
"""Simplified CART-style decision tree built with the Gini index.

The tree is grown as a chain: at every level the best binary split
``feature == value`` is chosen, one side becomes a leaf and the other side
is split further using the remaining features.
"""

import numpy as np

# Default tab-separated training file used when the module is run as a script.
FILENAME = 'decision_tree_text.txt'
# Upper bound on the number of levels the tree may grow.
MAXDEPTH = 10


class Cart(object):
    """Train a decision tree from a tab-separated text file and plot it.

    Each row of the file is one sample; the last column is the class label.
    """

    def __init__(self):
        # Holds the trained tree after trainDecisionTree(); a warning until then.
        self.trainresult = 'WARNING : please trainDecisionTree first!'

    def trainDecisionTree(self, filename):
        """Load the samples in *filename*, build the tree and plot it."""
        self.__loadDataSet(filename)
        self.__optimalTree(self.__datamat)

    def __loadDataSet(self, filename):
        """Read the tab-separated file and encode every text cell as a float.

        Sets ``__dataset`` (the raw text matrix), ``__textdic`` (one
        ``{text: code}`` dictionary per column) and ``__datamat`` (the
        numeric matrix used by the Gini computations).
        """
        # The context manager closes the file even if parsing fails.
        with open(filename) as fread:
            rows = [row.strip().split('\t') for row in fread if row.strip()]
        self.__dataset = np.array(rows)

        # Codes are assigned per column in order of first appearance.  A
        # single dictionary shared by all columns could hand the same code
        # to two different texts of one column, so each column gets its own.
        self.__textdic = {}
        for coln, col in enumerate(self.__dataset.T):
            codes = {}
            for cell in col:
                if cell not in codes:
                    codes[cell] = float(len(codes))
            self.__textdic[coln] = codes

        self.__datamat = np.array(
            [[self.__textdic[coln][cell] for coln, cell in enumerate(row)]
             for row in self.__dataset])

    def __getSampleCount(self, setd, col=-1, s=None):
        """Count how often each value occurs.

        Without *s*, the values of column *col* of *setd* are counted.  With
        *s*, the class labels (last column) of the rows whose column *col*
        equals *s* are counted.  Counts are floats so that later ratios use
        true division on every Python version.
        """
        if s is not None:
            # __getSampleMat returns (matching, rest); only the matching rows
            # are wanted here.
            newset = self.__getSampleMat(setd, col, s)[0][:, -1]
        else:
            newset = setd[:, col]

        dic = {}
        for cell in newset:
            dic[cell] = dic.get(cell, 0.) + 1
        return dic

    def __getSampleMat(self, setd, col, s):
        """Split *setd* into (rows where column *col* == *s*, all other rows).

        Both results stay two-dimensional even when they contain no rows, so
        callers can always slice columns from them.
        """
        setd = np.asarray(setd)
        mask = setd[:, col] == s
        return setd[mask], setd[~mask]

    def __getGiniD(self, setd):
        """Return the Gini index of the class labels (last column) of *setd*.

        An empty sample set has a Gini index of 0.
        """
        total = float(len(setd))
        if total == 0:
            return 0
        gini = 0
        for count in self.__getSampleCount(setd).values():
            ratio = count / total
            gini += ratio * (1 - ratio)
        return gini

    def __getGiniDA(self, setd, a):
        """Evaluate every binary split ``column a == value`` of *setd*.

        Returns ``(best, table)`` where *table* maps each value of column *a*
        to the Gini index of that split and *best* is the ``(value, gini)``
        pair with the smallest Gini index.
        """
        total = float(len(setd))
        dic = {}
        for value, count in self.__getSampleCount(setd, a).items():
            part_a, part_b = self.__getSampleMat(setd, a, value)
            weight = count / total
            dic[value] = (weight * self.__getGiniD(part_a) +
                          (1 - weight) * self.__getGiniD(part_b))
        # Select by Gini index; ties go to the smaller value so the result
        # does not depend on dictionary ordering.
        best = min(dic.items(), key=lambda pair: (pair[1], pair[0]))
        return best, dic

    def __optimalNode(self, setd):
        """Return ``(column, value, gini)`` of the best split of *setd*.

        Every column except the last (the class label) is a candidate; on
        equal Gini indices the leftmost column wins.

        Raises:
            ValueError: if *setd* has no feature column.
        """
        ginicol = 0
        mingini = None
        for coln in range(setd.shape[1] - 1):
            gini, _ = self.__getGiniDA(setd, coln)
            if mingini is None or gini[1] < mingini[1]:
                mingini = gini
                ginicol = coln
        if mingini is None:
            raise ValueError('sample matrix has no feature column to split on')
        return ginicol, mingini[0], mingini[1]

    def __optimalNodeText(self, col, value):
        """Return the original text whose code in column *col* is *value*.

        Returns None when no row of that column carries the code.
        """
        texts = self.__dataset[:, col]
        codes = self.__datamat[:, col]
        for text, code in zip(texts, codes):
            if code == value:
                return text
        return None

    def __optimalTree(self, setd):
        """Grow the tree level by level, then plot it.

        On success ``trainresult`` holds the plotted tree.
        NOTE(review): if MAXDEPTH is exhausted before a terminal split is
        reached, the method returns without touching ``trainresult``.

        Raises:
            ValueError: if the best split leaves one side empty, i.e. the
                remaining features cannot separate the samples any further.
        """
        arr = setd
        # features[i] is the column of the original matrix behind arr[:, i].
        features = np.arange(arr.shape[1])
        lst = []
        defaultc = None
        for _ in range(MAXDEPTH - 1):
            ginicol, value, gini = self.__optimalNode(arr)
            parts = self.__getSampleMat(arr, ginicol, value)
            if len(parts[1]) == 0:
                raise ValueError('remaining features cannot split the '
                                 'samples any further')
            args = [np.unique(part[:, -1]) for part in parts]
            realvalues = [np.unique(part[:, ginicol])[0] for part in parts]
            realcol = features[ginicol]
            features = np.delete(features, ginicol)

            if gini == 0 or arr.shape[1] == 2:
                # Terminal split: label the branch whose classes match the
                # class of the leaves emitted so far.
                same = (defaultc is not None and
                        np.array_equal(args[0], defaultc))
                value = realvalues[0] if same else realvalues[1]
                self.trainresult = self.__buildList(lst, realcol, value, gini)
                self.__plotTree(self.trainresult)
                return

            # The side holding a single class becomes a leaf; the other side
            # is split further without the feature just used.
            if len(args[0]) == 1:
                leaf, rest = 0, 1
            else:
                leaf, rest = 1, 0
            defaultc = args[leaf]
            self.__buildList(lst, realcol, realvalues[leaf], gini)
            arr = np.delete(parts[rest], ginicol, axis=1)

    def __plotTree(self, lst):
        """Turn the node list into nested dictionaries and plot them."""
        # Imported here because only plotting needs it; the Gini helpers stay
        # usable when the plotting module is unavailable.
        import plottree

        dic = {}
        for item in lst:
            if not dic:
                dic[item[0]] = {item[1]: 'c1', 'ELSE': 'c2'}
            else:
                dic = {item[0]: {item[1]: 'c1', 'ELSE': dic}}
        tree = plottree.retrieveTree(dic)
        self.trainresult = tree
        plottree.createPlot(tree)

    def __buildList(self, lst, col, value, gini):
        """Report one chosen split and put it at the front of *lst*.

        Returns *lst*, whose first entry is now ``[col, 'text:value']``.
        """
        text = self.__optimalNodeText(col, value)
        fields = ('feature col:', col, ' feature val:', text,
                  ' Gini:', gini, '\n')
        # A single pre-joined string prints identically on Python 2 and 3.
        print(' '.join(str(field) for field in fields))
        lst.insert(0, [col, str(text) + ':' + str(value)])
        return lst


if __name__ == '__main__':
    cart = Cart()
    cart.trainDecisionTree(FILENAME)
Reference:
李航. 统计学习方法