IMPLEMENTED IN PYTHON +1 | CART生成树

Introduction:

分类与回归树(classification and regression tree, CART)模型由Breiman等人在1984年提出,CART同样由特征选择、树的生成及剪枝组成,既可以用于分类也可以用于回归,以下简要讨论树生成部分,在随后的博文中再探讨树剪枝的问题。

Algorithm:

IMPLEMENTED IN PYTHON +1 | CART生成树
step 1. 对每个特征$A_i$的每个可能取值$j$,按“$A_i=j$”与“$A_i\neq j$”将数据集$D$切分为$D_1$、$D_2$两部分,并计算基尼指数$Gini(D,A_i=j)$

step 2. 在所有特征及其所有可能取值中,选择基尼指数最小的组合:若$Gini(D,A_i=j)$最小,则$A_i$为最优特征,作为根节点,$A_i=j$为最优切分点

step 3. 对切分得到的子节点,在剩余的特征中重复step 1和step 2,获取最优特征及最优切分点,直至所有特征用尽或者所有样本都已归类。对本文的示例数据而言,最后所生成的决策树与ID3算法所生成的完全一致
IMPLEMENTED IN PYTHON +1 | CART生成树

Formula:

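$Gini(D)=\sum_{k=1}^{K}\frac{|C_k|}{|D|}\left(1-\frac{|C_k|}{|D|}\right)=1-\sum_{k=1}^{K}\left(\frac{|C_k|}{|D|}\right)^2$

$Gini(D,A=a)=\frac{|D_1|}{|D|}Gini(D_1)+\frac{|D_2|}{|D|}Gini(D_2)$

其中$C_k$是$D$中属于第$k$类的样本子集,$K$为类别数;$D_1=\{(x,y)\in D \mid A(x)=a\}$,$D_2=D-D_1$。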

Code:

IMPLEMENTED IN PYTHON +1 | CART生成树
  1 # -*- coding: utf-8 -*-
  2 
  3 FILENAME = decision_tree_text.txt
  4 MAXDEPTH = 10
  5 
  6 import numpy as np
  7 import plottree
  8 
  9 class Cart():
 10     def __init__(self):
 11         self.trainresult = WARNING : please trainDecisionTree first!
 12         pass
 13 
 14 
 15     def trainDecisionTree(self, filename):
 16         self.__loadDataSet(filename)
 17         self.__optimalTree(self.__datamat)
 18 
 19 
 20     def __loadDataSet(self, filename):
 21         fread   = open(filename)
 22         self.__dataset = np.array([row.strip().split(\t)  23                   for row in fread.readlines()])
 24         self.__textdic     = {}
 25         for col in self.__dataset.T:
 26             i = .0
 27             for cell in col:
 28                 if not self.__textdic.has_key(cell):
 29                     self.__textdic[cell] = i
 30                     i += 1
 31         self.__datamat = np.array([np.array([(lambda cell:self.__textdic[cell])(cell)  32                   for cell in row])  33                   for row in self.__dataset])
 34 
 35 
 36     def __getSampleCount(self, setd, col = -1, s = None):
 37         dic = {}
 38 
 39         if s is not None:
 40             newset = self.__getSampleMat(setd,col,s)[:,-1]
 41         else:
 42             newset = setd[:,col]
 43 
 44         for cell in newset:
 45             if not dic.has_key(cell):
 46                 dic[cell] =  1.
 47             else:
 48                 dic[cell] += 1
 49         return dic
 50 
 51 
 52     def __getSampleMat(self, setd, col, s):
 53         lista = []; listb = []
 54         for row in setd:
 55             if row[col] == s:
 56                 lista.append(row)
 57             else:
 58                 listb.append(row)
 59         return  np.array(lista), np.array(listb)
 60 
 61 
 62     def __getGiniD(self, setd):
 63         sample_count = self.__getSampleCount(setd)
 64         gini = 0
 65         for item in sample_count.items():
 66             gini += item[1]/len(setd) * (1- item[1]/len(setd))
 67         return gini
 68 
 69 
 70     def __getGiniDA(self, setd, a):
 71         sample_count = self.__getSampleCount(setd, a)
 72         dic = {}
 73         for item in sample_count.items():
 74             setd_part_a, setd_part_b = self.__getSampleMat(setd, a, item[0])
 75             gini = item[1]/len(setd) * self.__getGiniD(setd_part_a) +  76                    (1- item[1]/len(setd)) * self.__getGiniD(setd_part_b)
 77             dic[item[0]]=gini
 78         return min(dic.items()), dic
 79 
 80 
 81     def __optimalNode(self, setd):
 82         coln    = 0
 83         ginicol = 0
 84         mingini = {1:1}
 85         for col in setd[:,:-1].T:
 86             gini, dic = self.__getGiniDA(setd, coln)
 87             if gini[1] < mingini[1]:
 88                 mingini = gini
 89                 ginicol = coln
 90             coln += 1
 91         return ginicol, mingini[0], mingini[1]
 92 
 93 
 94     def __optimalNodeText(self, col, value):
 95         row = 0
 96         tex = None
 97         for cell in self.__dataset.T[col]:
 98             if self.__datamat[row,col] == value:
 99                 tex = cell
100                 break
101             row += 1
102         return tex
103 
104 
105     def __optimalTree(self, setd):
106         arr      = setd
107         count    = MAXDEPTH-1
108         features = np.array(range(len(arr.T)))
109         lst      = []
110         defaultc = None
111         while count > 0:
112             count               -= 1
113             ginicol, value, gini = self.__optimalNode(arr)
114             parts                = self.__getSampleMat(arr, ginicol, value)
115             args                 = [np.unique(part[:,-1]) for part in parts]
116             realvalues           = [np.unique(part[:,ginicol])[0] for part in parts]
117             realcol              = features[ginicol]
118             features             = np.delete(features, ginicol)
119             if gini == 0 or len(arr.T) == 2:
120                 if args[0] == defaultc:
121                     value  = realvalues[0]
122                 else:
123                     value  = realvalues[1]
124                 self.trainresult = self.__buildList(lst, realcol, value, gini)
125                 self.__plotTree(self.trainresult)
126                 return
127             if len(args[0]) == 1:
128                 defaultc = args[0]
129                 self.__buildList(lst, realcol, realvalues[0], gini)
130                 arr = np.concatenate((parts[1][:,:ginicol], 131                       parts[1][:,ginicol+1:]), axis=1)
132             else:
133                 defaultc = args[1]
134                 self.__buildList(lst, realcol, realvalues[1], gini)
135                 arr = np.concatenate((parts[0][:,:ginicol], 136                       parts[0][:,ginicol+1:]), axis=1)
137 
138 
139     def __plotTree(self, lst):
140         dic = {}
141         for item in lst:
142             if dic == {}:
143                 dic[item[0]] = {item[1]:c1,ELSE:c2}
144             else:
145                 dic = {item[0]:{item[1]:c1,ELSE:dic}}
146         tree = plottree.retrieveTree(dic)
147         self.trainresult = tree
148         plottree.createPlot(tree)
149 
150 
151     def __buildList(self, lst, col, value, gini):
152         print feature col:, col, 153                       feature val:, self.__optimalNodeText(col, value), 154                       Gini:, gini, \n
155         lst.insert(0,[col,str(self.__optimalNodeText(col, 156                    value))+:+str(value)])
157         return lst
158 
159 
160 
if __name__ == "__main__":
    cart = Cart()
    # FILENAME is otherwise unused: train on it so running the script builds
    # and plots the tree.
    cart.trainDecisionTree(FILENAME)
IMPLEMENTED IN PYTHON +1 | CART生成树

Reference:

李航. 统计学习方法

IMPLEMENTED IN PYTHON +1 | CART生成树

上一篇:【python】中文的输出,打印


下一篇:VC++ 汇编相关的东西