本文主要以mllib 1.1版本为基础,分析决策树的基本原理与源码
一、基本原理
二、源码分析
1、决策树构造
指定决策树训练数据集与策略(Strategy)通过train函数就能得到决策树模型DecisionTreeModel
决策树策略包含了:algo(算法类型:分类、回归),impurity(信息增益计算算法)、maxDepth(数最大深度)、
numClassesForClassification(数分类分支数目,为2就是二叉数),maxBins(特征变量最大的分类数目限制)、
quantileCalculationStrategy(分位数计算方法)、categoricalFeaturesInfo(每个特征的分类数目)
2、模型训练
(1)DecisionTree的模型训练函数train主要包含了findSplitsBins、findBestSplits、DecisionTreeModel三部分(入下图所示,为了方便分析,不重要的代码做了删减)
步骤一:findSplitsBins找出数据集中每个变量(Features)对应的所有分裂方式
步骤二:findBestSplits通过计算信息增益来寻找每个节点的最佳的分裂点
步骤三:DecisionTreeModel构造决策树模型
(2)findSplitsBins
通过抽样的方法来近似分位数的计算,抽样样本的的最大数目为maxBins*maxBins
针对每个变量进行迭代,如果是特征是连续变量,先对数据进行排序,然后对数据进行分箱,切成maxBins块,
每块的数目是stride个。对于maxBins块数据则存在maxBins-1种分裂方式。
例如:数据集合data包含1000条数据,两个变量,第一个变量从0-999,第二个变量从999-0
可以看出splits的分裂方式有99种
具体split分裂点的threshold
如果特征变量属于离散变量,又分为2种情况,有序的和无序的
对于无序的离散变量,如果它有n个分类,则分裂的方式就有2^n-1种
如下数据集合包含1000条数据,2个变量。每个变量包含2种分类且是无序的。通过debug可以看出每个变量都有3个split