推荐阅读:
ID3算法 wiki
决策树算法及实现完整示例代码:
JS简单实现决策树(ID3算法)_demo.html
决策树算法代码实现
1.准备测试数据
这里我假设公司有个小姐姐相亲见面为例
得到以下是已经见面或被淘汰了的数据(部分数据使用mock.js来生成的):
var data =
[
{ "姓名": "余夏", "年龄": 29, "长相": "帅", "体型": "瘦", "收入": "高", 见面: "见" },
{ "姓名": "豆豆", "年龄": 25, "长相": "帅", "体型": "瘦", "收入": "高", 见面: "见" },
{ "姓名": "帅常荣", "年龄": 26, "长相": "帅", "体型": "胖", "收入": "高", 见面: "见" },
{ "姓名": "王涛", "年龄": 22, "长相": "帅", "体型": "瘦", "收入": "高", 见面: "见" },
{ "姓名": "李东", "年龄": 23, "长相": "帅", "体型": "瘦", "收入": "高", 见面: "见" },
{ "姓名": "王五五", "年龄": 23, "长相": "帅", "体型": "瘦", "收入": "低", 见面: "见" },
{ "姓名": "王小涛", "年龄": 22, "长相": "帅", "体型": "瘦", "收入": "低", 见面: "见" },
{ "姓名": "李缤", "年龄": 21, "长相": "帅", "体型": "胖", "收入": "高", 见面: "见" },
{ "姓名": "刘明", "年龄": 21, "长相": "帅", "体型": "胖", "收入": "低", 见面: "不见" },
{ "姓名": "红鹤", "年龄": 21, "长相": "不帅", "体型": "胖", "收入": "高", 见面: "不见" },
{ "姓名": "李理", "年龄": 32, "长相": "帅", "体型": "瘦", "收入": "高", 见面: "不见" },
{ "姓名": "周州", "年龄": 31, "长相": "帅", "体型": "瘦", "收入": "高", 见面: "不见" },
{ "姓名": "李乐", "年龄": 27, "长相": "不帅", "体型": "胖", "收入": "高", 见面: "不见" },
{ "姓名": "韩明", "年龄": 24, "长相": "不帅", "体型": "瘦", "收入": "高", 见面: "不见" },
{ "姓名": "小吕", "年龄": 28, "长相": "帅", "体型": "瘦", "收入": "低", 见面: "不见" },
{ "姓名": "李四", "年龄": 25, "长相": "帅", "体型": "瘦", "收入": "低", 见面: "不见" },
{ "姓名": "王鹏", "年龄": 30, "长相": "帅", "体型": "瘦", "收入": "低", 见面: "不见" },
];
2.搭建决策树基本函数
代码:
function DecisionTree(config) {
if (typeof config == "object" && !Array.isArray(config)) this.training(config);
};
DecisionTree.prototype = {
//分割函数
_predicates: {},
//统计属性值在数据集中的次数
countUniqueValues(items, attr) {},
//获取对象中值最大的Key 假设 counter={a:9,b:2} 得到 "a"
getMaxKey(counter) {},
//寻找最频繁的特定属性值
mostFrequentValue(items, attr) {},
//根据属性切割数据集
split(items, attr, predicate, pivot) {},
//计算熵
entropy(items, attr) {},
//生成决策树
buildDecisionTree(config) {},
//初始化生成决策树
training(config) {},
//预测 测试
predict(data) {},
};
var decisionTree = new DecisionTree();
3.实现函数功能
由于部分函数过于简单我就不进行讲解了
可前往 JS简单实现决策树(ID3算法)_demo.html查看完整代码
里面包含注释,与每个函数的测试方法这里的话我主要讲解下:计算熵的函数、生成决策树函数(信息增益)、与预测函数的实现
在 ID3算法 wiki 中解释了计算熵与信息增益的公式
3.1.计算熵(entropy)函数
根据公式:
我们可以知道计算H(S)(也就是熵)需要得到 p(x)=x/总数量
然后进行计算累加就行了
代码:
//......略
//统计属性值在数据集中的次数
countUniqueValues(items, attr) {
var counter = {}; // 获取不同的结果值 与出现次数
for (var i of items) {
if (!counter[i[attr]]) counter[i[attr]] = 0;
counter[i[attr]] += 1;
}
return counter;
},
//......略
//计算熵
entropy(items, attr) {
var counter = this.countUniqueValues(items, attr); //计算值的出现数
var p, entropy = 0; //H(S)=entropy=∑(P(Xi)(log2(P(Xi))))
for (var i in counter) {
p = counter[i] / items.length; //P(Xi)概率值
entropy += -p * Math.log2(p); //entropy+=-(P(Xi)(log2(P(Xi))))
}
return entropy;
},
//......略
var decisionTree = new DecisionTree();
console.log("函数 countUniqueValues 测试:");
console.log(" 长相", decisionTree.countUniqueValues(data, "长相")); //测试
console.log(" 年龄", decisionTree.countUniqueValues(data, "年龄")); //测试
console.log(" 收入", decisionTree.countUniqueValues(data, "收入")); //测试
console.log("函数 entropy 测试:");
console.log(" 长相", decisionTree.entropy(data, "长相")); //测试
console.log(" 年龄", decisionTree.entropy(data, "年龄")); //测试
console.log(" 收入", decisionTree.entropy(data, "收入")); //测试
3.2.信息增益
根据公式我们知道要得到信息增益的值需要得到:
- H(S) 训练集熵
- p(t)分支元素的占比
- H(t)分支数据集的熵
其中t我们就先分 match(合适的)和on match(不合适),所以H(t):
- H(match) 分割后合适的数据集的熵
- H(on match) 分割后不合适的数据集的熵
所以信息增益G=H(S)-(p(match)H(match)+p(on match)H(on match))
因为p(match)=match数量/数据集总项数量
信息增益G=H(S)-((match数量)xH(match)+(on match数量)xH(on match))/数据集总项数量
//......略
buildDecisionTree(config){
var trainingSet = config.trainingSet;//训练集
var categoryAttr = config.categoryAttr;//用于区分的类别属性
//......略
//初始计算 训练集的熵
var initialEntropy = this.entropy(trainingSet, categoryAttr);//<===H(S)
//......略
var alreadyChecked = [];//标识已经计算过了
var bestSplit = { gain: 0 };//储存当前最佳的分割节点数据信息
//遍历数据集
for (var item of trainingSet) {
// 遍历项中的所有属性
for (var attr in item) {
//跳过区分属性与忽略属性
if ((attr == categoryAttr) || (ignoredAttributes.indexOf(attr) >= 0)) continue;
var pivot = item[attr];// 当前属性的值
var predicateName = ((typeof pivot == 'number') ? '>=' : '=='); //根据数据类型选择判断条件
var attrPredPivot = attr + predicateName + pivot;
if (alreadyChecked.indexOf(attrPredPivot) >= 0) continue;//已经计算过则跳过
alreadyChecked.push(attrPredPivot);//记录
var predicate = this._predicates[predicateName];//匹配分割方式
var currSplit = this.split(trainingSet, attr, predicate, pivot);
var matchEntropy = this.entropy(currSplit.match, categoryAttr);// H(match) 计算分割后合适的数据集的熵
var notMatchEntropy = this.entropy(currSplit.notMatch, categoryAttr);// H(on match) 计算分割后不合适的数据集的熵
//计算信息增益:
// IG(A,S)=H(S)-(∑P(t)H(t)))
// t为分裂的子集match(匹配),on match(不匹配)
// P(match)=match的长度/数据集的长度
// P(on match)=on match的长度/数据集的长度
var iGain = initialEntropy - ((matchEntropy * currSplit.match.length
+ notMatchEntropy * currSplit.notMatch.length) / trainingSet.length);
//不断匹配最佳增益值对应的节点信息
if (iGain > bestSplit.gain) {
//......略
}
}
}
//......递归计算分支
}
3.3.预测功能
预测功能的话就只要将要预测的值传入,循环去寻找符合条件的分支,直到找到最后的所属分类为止,这里就不详细解释了
代码:
//......略
//预测 测试
predict(data) {
var attr, value, predicate, pivot;
var tree = this.root;
while (true) {
if (tree.category) {
return tree.category;
}
attr = tree.attribute;
value = data[attr];
predicate = tree.predicate;
pivot = tree.pivot;
if (predicate(value, pivot)) {
tree = tree.match;
} else {
tree = tree.notMatch;
}
}
}
//......略