Java 实现 GBDT(梯度提升决策树)

Data 类

import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.Scanner;

/**
 * Loads a comma-separated text file into memory as rows of raw string fields.
 *
 * <p>Each non-blank line is trimmed as a whole and split on {@code ','}; the
 * individual fields are stored exactly as split (they are NOT trimmed, so a
 * field may keep a leading space). If the file cannot be opened the error is
 * reported on {@code System.err} and the row list stays empty.
 */
public class Data {
    /** File read by the no-argument constructor. */
    private static final String DEFAULT_DATA_PATH =
            "D://javajavajava//dbdt//src//script//data//adult.data.csv";

    /** Parsed rows, in file order. */
    private ArrayList<ArrayList<String>> trainData = new ArrayList<ArrayList<String>>();

    /** Loads the default data file. */
    public Data() {
        this(DEFAULT_DATA_PATH);
    }

    /**
     * Loads the given data file.
     *
     * @param dataPath path of the comma-separated file to read
     */
    public Data(String dataPath) {
        // try-with-resources guarantees the Scanner (and its file handle) is closed.
        try (Scanner in = new Scanner(new File(dataPath))) {
            while (in.hasNext()) {
                String line = in.nextLine().trim();
                if (line.isEmpty()) {
                    // A blank line would otherwise become a one-field row and
                    // break callers that index columns by position.
                    continue;
                }
                String[] fields = line.split(",");
                ArrayList<String> row = new ArrayList<>(fields.length);
                for (String field : fields) {
                    row.add(field);
                }
                this.trainData.add(row);
            }
        } catch (FileNotFoundException e) {
            // Best effort: keep the object usable with an empty data set.
            System.err.println("Data file not found: " + dataPath);
            e.printStackTrace();
        }
    }

    /**
     * Returns the parsed rows.
     *
     * @return the internal (mutable) list of rows; empty if loading failed
     */
    public ArrayList<ArrayList<String>> getTrainData() {
        return this.trainData;
    }

    public static void main(String[] args) {
        Data d = new Data();
    }
}

Tree 类

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Random;
import java.util.spi.TimeZoneNameProvider;

/**
 * A node of a binary regression tree used as the weak learner of the GBDT.
 *
 * <p>An internal node stores the index of the attribute it splits on and the
 * split value. For a numeric attribute a sample goes left when its integer
 * value is {@code <=} the split value; for a categorical attribute it goes left
 * when its trimmed value equals the split value. A leaf stores the step value
 * computed by {@link #getPredictValue(int, ArrayList, ArrayList)}.
 *
 * <p>Throughout this class {@code subIdx} holds row indices into
 * {@code trainData}, and {@code target} is indexed by those same row indices.
 */
public class Tree {
    // Children start out null: creating them eagerly in a field initializer
    // would recurse forever and overflow the stack on every "new Tree()".
    private Tree leftTree;
    private Tree rightTree;
    private double loss = -1;
    private int attributeSplit = 0;
    private String attributeSplitType = "";
    boolean isLeaf;
    double leafValue;
    private ArrayList<Integer> leafNodeSet = new ArrayList<>();

    /**
     * Collects the distinct raw values found in one column.
     *
     * @param trainData all rows
     * @param idx       column index
     * @return the distinct values of column {@code idx}, in no particular order
     */
    public ArrayList<String> getAttributeSet(ArrayList<ArrayList<String>> trainData, int idx) {
        HashSet<String> distinct = new HashSet<>();
        for (ArrayList<String> row : trainData) {
            distinct.add(row.get(idx));
        }
        return new ArrayList<>(distinct);
    }

    /**
     * Compares two strings as integers.
     *
     * @return {@code true} when {@code str1 <= str2} after trimming and integer parsing
     * @throws NumberFormatException if either string is not an integer
     */
    public boolean myCmpLess(String str1, String str2) {
        return Integer.parseInt(str1.trim()) <= Integer.parseInt(str2.trim());
    }

    /**
     * Computes the square root of the sum of squared deviations from the mean.
     *
     * @param values the values to measure
     * @return the dispersion of {@code values}; 0 for an empty list
     */
    public double computeLoss(ArrayList<Double> values) {
        double total = 0;
        for (double v : values) {
            total += v;
        }
        double mean = total / values.size();
        double squared = 0;
        for (double v : values) {
            squared += Math.pow(v - mean, 2);
        }
        return Math.sqrt(squared);
    }

    /**
     * Computes the leaf step for a K-class logistic gradient boosting tree:
     * {@code (K-1)/K * sum(r) / sum(|r| * (1 - |r|))} over the residuals
     * {@code r} of the samples in the leaf.
     *
     * @param K      number of classes
     * @param subIdx row indices of the samples in the leaf
     * @param target residuals, indexed by row index
     * @return the leaf value, or 0 when the denominator is 0 (including an empty leaf)
     */
    public double getPredictValue(int K, ArrayList<Integer> subIdx, ArrayList<Double> target) {
        double numerator = 0;
        double denominator = 0;
        for (int idx : subIdx) {
            double r = target.get(idx);
            double magnitude = Math.abs(r);
            numerator += r;
            denominator += magnitude * (1 - magnitude);
        }
        if (denominator == 0) {
            return 0;
        }
        // (K - 1.0) forces floating-point division; (K-1)/K on ints is always 0.
        return (K - 1.0) / K * numerator / denominator;
    }

    /** Returns the leaf value stored on {@code root}. */
    public double getPredictValue(Tree root) {
        return root.leafValue;
    }

    /**
     * Routes one instance down the tree and returns the value of the leaf it reaches.
     *
     * @param root     root of the tree
     * @param instance attribute values of the sample
     * @param isDigit  per-attribute flag, {@code true} for numeric attributes
     * @return the value of the reached leaf
     */
    public double getPredictValue(Tree root, ArrayList<String> instance, Boolean isDigit[]) {
        Tree node = root;
        while (!node.isLeaf) {
            String value = instance.get(node.attributeSplit).trim();
            boolean goLeft = isDigit[node.attributeSplit]
                    ? myCmpLess(value, node.attributeSplitType)
                    : value.equals(node.attributeSplitType);
            node = goLeft ? node.leftTree : node.rightTree;
        }
        return node.leafValue;
    }

    /**
     * Recursively builds a tree over the samples in {@code subIdx}.
     *
     * <p>At each node every attribute is tried. Candidate split values are the
     * distinct values present in the node's samples; for numeric attributes at
     * most {@code splitPoints} of them are sampled at random. The candidate
     * with the smallest summed {@link #computeLoss} over both sides wins.
     * Candidates that leave one side empty are ignored. A node becomes a leaf
     * when {@code depth >= maxDepth[0]}, when it has fewer than two samples, or
     * when no usable split exists.
     *
     * @param leafNodes   output: the sample-index list of every leaf created
     * @param leafValues  output: the value of every leaf created, parallel to {@code leafNodes}
     * @param K           number of classes (used for the leaf value)
     * @param splitPoints maximum number of thresholds tried per numeric attribute;
     *                    values {@code <= 0} mean "try all"
     * @param isDigit     per-attribute flag, {@code true} for numeric attributes
     * @param subIdx      row indices of the samples at this node
     * @param trainData   all rows
     * @param target      regression targets, indexed by row index
     * @param maxDepth    single-element array holding the depth limit
     * @param depth       depth of this node (root is 0)
     * @return the root of the constructed (sub)tree
     */
    public Tree constructTree(ArrayList<ArrayList<Integer>> leafNodes, ArrayList<Double> leafValues, int K, int splitPoints, Boolean isDigit[], ArrayList<Integer> subIdx, ArrayList<ArrayList<String>> trainData, ArrayList<Double> target, int maxDepth[], int depth) {
        if (depth >= maxDepth[0] || subIdx.size() < 2 || trainData.isEmpty()) {
            return makeLeaf(leafNodes, leafValues, K, subIdx, target);
        }

        // Never look at more columns than we have type information for.
        int dim = Math.min(trainData.get(0).size(), isDigit.length);
        Random random = new Random();

        double bestLoss = -1;
        ArrayList<Integer> bestLeft = null;
        ArrayList<Integer> bestRight = null;
        int bestAttribute = 0;
        String bestValue = "";

        for (int i = 0; i < dim; i++) {
            boolean numeric = isDigit[i];
            ArrayList<String> candidates =
                    candidateSplitValues(trainData, subIdx, i, numeric, splitPoints, random);
            for (String candidate : candidates) {
                // Fresh partitions per candidate; reusing lists would mix candidates together.
                ArrayList<Integer> left = new ArrayList<>();
                ArrayList<Integer> right = new ArrayList<>();
                for (Integer sample : subIdx) {
                    String value = trainData.get(sample).get(i).trim();
                    boolean goLeft = numeric ? myCmpLess(value, candidate) : value.equals(candidate);
                    if (goLeft) {
                        left.add(sample);
                    } else {
                        right.add(sample);
                    }
                }
                if (left.isEmpty() || right.isEmpty()) {
                    continue; // not a real split
                }
                double candidateLoss =
                        computeLoss(targetsOf(left, target)) + computeLoss(targetsOf(right, target));
                if (bestLeft == null || candidateLoss < bestLoss) {
                    bestLoss = candidateLoss;
                    bestLeft = left;
                    bestRight = right;
                    bestAttribute = i;
                    bestValue = candidate;
                }
            }
        }

        if (bestLeft == null) {
            return makeLeaf(leafNodes, leafValues, K, subIdx, target);
        }

        Tree node = new Tree();
        node.attributeSplit = bestAttribute;
        node.attributeSplitType = bestValue;
        node.loss = bestLoss;
        node.isLeaf = false;
        node.leftTree = constructTree(leafNodes, leafValues, K, splitPoints, isDigit, bestLeft, trainData, target, maxDepth, depth + 1);
        node.rightTree = constructTree(leafNodes, leafValues, K, splitPoints, isDigit, bestRight, trainData, target, maxDepth, depth + 1);
        return node;
    }

    /** Creates a leaf for {@code subIdx} and records it in the two output lists. */
    private Tree makeLeaf(ArrayList<ArrayList<Integer>> leafNodes, ArrayList<Double> leafValues, int K, ArrayList<Integer> subIdx, ArrayList<Double> target) {
        Tree leaf = new Tree();
        leaf.isLeaf = true;
        leaf.leafValue = getPredictValue(K, subIdx, target);
        leaf.leafNodeSet.addAll(subIdx);
        leafNodes.add(subIdx);
        leafValues.add(leaf.leafValue);
        return leaf;
    }

    /** Returns the targets of the given rows, in the same order. */
    private ArrayList<Double> targetsOf(ArrayList<Integer> indices, ArrayList<Double> target) {
        ArrayList<Double> values = new ArrayList<>(indices.size());
        for (int idx : indices) {
            values.add(target.get(idx));
        }
        return values;
    }

    /**
     * Returns the trimmed distinct values of one attribute among the given rows;
     * for numeric attributes with more than {@code splitPoints} values, a random
     * sample of {@code splitPoints} distinct values is returned instead.
     */
    private ArrayList<String> candidateSplitValues(ArrayList<ArrayList<String>> trainData, ArrayList<Integer> subIdx, int attribute, boolean numeric, int splitPoints, Random random) {
        HashSet<String> distinct = new HashSet<>();
        for (Integer sample : subIdx) {
            distinct.add(trainData.get(sample).get(attribute).trim());
        }
        ArrayList<String> values = new ArrayList<>(distinct);
        if (!numeric || splitPoints <= 0 || values.size() <= splitPoints) {
            return values;
        }
        // Sample without replacement so the thresholds are all different.
        ArrayList<String> sampled = new ArrayList<>(splitPoints);
        while (sampled.size() < splitPoints) {
            int pick = random.nextInt(values.size());
            sampled.add(values.remove(pick));
        }
        return sampled;
    }

    public static void main(String[] args) {
        Tree aTree = new Tree();
    }
}

  

GBDT类

import java.rmi.server.SkeletonNotFoundException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.Set;

/**
 * Gradient boosted decision trees for a two-class problem.
 *
 * <p>Rows are loaded through {@link Data}. The last column is the label
 * ({@code "<=50K"} maps to -1, anything else to 1) and the remaining columns
 * are features. Each boosting round draws a random subset of the samples,
 * computes the softmax residuals on it, fits one {@link Tree} per class and
 * adds the scaled leaf values to the score matrix {@code F}.
 */
public class GBDT {
    private ArrayList<ArrayList<String>> datas = new ArrayList<ArrayList<String>>();
    private ArrayList<String> labelSets = new ArrayList<>();
    /** Current scores, one row per training sample and one column per class. */
    private ArrayList<ArrayList<Double>> F = new ArrayList<ArrayList<Double>>();
    /** Residuals, same shape as {@code F}. */
    private ArrayList<ArrayList<Double>> residual = new ArrayList<ArrayList<Double>>();
    private ArrayList<ArrayList<String>> trainData = new ArrayList<ArrayList<String>>();
    private ArrayList<Integer> labelTrainData = new ArrayList<Integer>();
    private int K;
    private Boolean isDigit[];
    private int dim;
    private int n;
    private double learningRate;
    /** All fitted trees: one inner list (K trees) per boosting round. */
    private ArrayList<ArrayList<Tree>> trees = new ArrayList<ArrayList<Tree>>();
    private int max_iter;
    private double sampleRate;
    private int maxDepth;
    private int splitPoints;

    /** Sets the hyper-parameters and loads the training data. */
    public GBDT() {
        this.max_iter = 50;
        this.sampleRate = 0.8;
        this.K = 2; // two-class problem
        this.maxDepth = 6;
        this.splitPoints = 3;
        this.learningRate = 0.01;
        getData();
    }

    /**
     * Recomputes the residuals {@code y_k - p_k} for the given samples, where
     * {@code p_k} is the softmax of the sample's two scores and {@code y_k} is
     * 1 for the sample's own class (label -1 is class 0, otherwise class 1).
     *
     * @param subId row indices of the samples to update
     */
    public void computeResidual(ArrayList<Integer> subId) {
        for (int idx : subId) {
            int positive = this.labelTrainData.get(idx) == -1 ? 0 : 1;
            double f0 = this.F.get(idx).get(0);
            double f1 = this.F.get(idx).get(1);
            // Subtracting the max keeps exp() from overflowing; the ratio is unchanged.
            double shift = Math.max(f0, f1);
            double e0 = Math.exp(f0 - shift);
            double e1 = Math.exp(f1 - shift);
            double sum = e0 + e1;
            this.residual.get(idx).set(0, (1 - positive) - e0 / sum);
            this.residual.get(idx).set(1, positive - e1 / sum);
        }
    }

    /**
     * Draws distinct random integers from {@code [0, maxNum)}.
     *
     * @param maxNum exclusive upper bound
     * @param num    how many values are wanted; capped at {@code maxNum}
     * @return the drawn values, in no particular order
     */
    public ArrayList<Integer> myrandom(int maxNum, int num) {
        Set<Integer> picked = new HashSet<>();
        // More than maxNum distinct values do not exist; without the cap the loop never ends.
        int wanted = Math.min(num, maxNum);
        Random r = new Random();
        while (picked.size() < wanted) {
            picked.add(r.nextInt(maxNum));
        }
        return new ArrayList<>(picked);
    }

    /** Runs {@code max_iter} boosting rounds; does nothing when there is no training data. */
    public void train() {
        if (this.n == 0) {
            return;
        }
        int maxdepths[] = {this.maxDepth};
        int numSubset = (int) (this.n * this.sampleRate);
        for (int i = 0; i < max_iter; i++) {
            ArrayList<Integer> subSet = myrandom(this.n, numSubset);
            computeResidual(subSet);
            ArrayList<Tree> roundTrees = new ArrayList<>();
            for (int j = 0; j < this.K; j++) {
                // Tree looks targets up by row index, so the list must cover
                // every row, not just the sampled ones.
                ArrayList<Double> target = new ArrayList<>(this.n);
                for (int k = 0; k < this.n; k++) {
                    target.add(residual.get(k).get(j));
                }
                ArrayList<ArrayList<Integer>> leafNodes = new ArrayList<ArrayList<Integer>>();
                ArrayList<Double> leafValues = new ArrayList<>();
                Tree treeSub = new Tree();
                Tree iterTree = treeSub.constructTree(leafNodes, leafValues, K, splitPoints, isDigit, subSet, trainData, target, maxdepths, 0);
                roundTrees.add(iterTree);
                updateFvalue(isDigit, subSet, leafNodes, leafValues, j, iterTree);
            }
            trees.add(roundTrees);
        }
    }

    /**
     * Adds {@code learningRate * leafValue} to column {@code label} of {@code F}
     * for every training sample. Samples used to fit the tree take the value of
     * the leaf they were assigned to; all other samples are routed through the tree.
     *
     * @param isDigit    per-attribute flag, {@code true} for numeric attributes
     * @param subIdx     row indices the tree was fitted on
     * @param leafNodes  sample-index list of each leaf
     * @param leafValues value of each leaf, parallel to {@code leafNodes}
     * @param label      class column of {@code F} to update
     * @param root       root of the fitted tree
     */
    public void updateFvalue(Boolean isDigit[], ArrayList<Integer> subIdx, ArrayList<ArrayList<Integer>> leafNodes, ArrayList<Double> leafValues, int label, Tree root) {
        boolean[] inSubset = new boolean[this.n];
        for (int idx : subIdx) {
            inSubset[idx] = true;
        }
        for (int i = 0; i < leafNodes.size(); i++) {
            // Each leaf contributes its own value, not the root's.
            double step = this.learningRate * leafValues.get(i);
            for (int sample : leafNodes.get(i)) {
                ArrayList<Double> scores = this.F.get(sample);
                scores.set(label, scores.get(label) + step);
            }
        }
        // Samples left out of the subset are scored by walking the tree.
        for (int i = 0; i < this.n; i++) {
            if (inSubset[i]) {
                continue;
            }
            double leafV = root.getPredictValue(root, this.trainData.get(i), isDigit);
            ArrayList<Double> scores = this.F.get(i);
            scores.set(label, scores.get(label) + this.learningRate * leafV);
        }
    }

    /**
     * Tells whether a field is a non-negative integer literal.
     *
     * @param str the field; surrounding whitespace is ignored
     * @return {@code true} when the trimmed field is non-empty and all digits
     */
    public boolean checkDigit(String str) {
        String s = str.trim();
        if (s.isEmpty()) {
            return false;
        }
        for (int i = 0; i < s.length(); i++) {
            char c = s.charAt(i);
            if (c < '0' || c > '9') {
                return false;
            }
        }
        return true;
    }

    /**
     * Loads the data set and (re)initialises all training state. Safe to call
     * more than once: previous state is discarded first.
     *
     * <p>The first row is used only to decide which columns are numeric; rows
     * from the second one on become training samples unless they are too short
     * or contain a {@code "?"} field.
     * NOTE(review): the first row itself is not used for training — confirm
     * whether the input file really starts with a non-sample row.
     */
    public void getData() {
        this.labelSets.clear();
        this.trainData.clear();
        this.labelTrainData.clear();
        this.F.clear();
        this.residual.clear();

        Data d = new Data();
        this.datas = d.getTrainData();
        if (this.datas.isEmpty()) {
            this.dim = 0;
            this.n = 0;
            this.isDigit = new Boolean[0];
            return;
        }

        ArrayList<String> firstRow = this.datas.get(0);
        this.dim = Math.max(firstRow.size() - 1, 0);
        this.isDigit = new Boolean[this.dim];
        for (int i = 0; i < this.dim; i++) {
            this.labelSets.add(firstRow.get(i));
            this.isDigit[i] = checkDigit(firstRow.get(i));
        }

        for (int i = 1; i < this.datas.size(); i++) {
            ArrayList<String> row = this.datas.get(i);
            if (row.size() <= this.dim || hasMissingValue(row)) {
                continue;
            }
            // Copy the feature columns; the label column is kept separately.
            this.trainData.add(new ArrayList<>(row.subList(0, this.dim)));
            this.labelTrainData.add(row.get(this.dim).trim().equals("<=50K") ? -1 : 1);
        }
        this.n = this.labelTrainData.size();

        // F and residual start at zero. Each needs its own row objects,
        // otherwise writing a residual would overwrite the score.
        for (int i = 0; i < this.n; i++) {
            ArrayList<Double> scoreRow = new ArrayList<Double>();
            ArrayList<Double> residualRow = new ArrayList<Double>();
            for (int j = 0; j < this.K; j++) {
                scoreRow.add(0.0);
                residualRow.add(0.0);
            }
            this.F.add(scoreRow);
            this.residual.add(residualRow);
        }
    }

    /** Returns true when any feature or the label of the row is the marker "?". */
    private boolean hasMissingValue(ArrayList<String> row) {
        for (int j = 0; j <= this.dim; j++) {
            if (row.get(j).trim().equals("?")) {
                return true;
            }
        }
        return false;
    }

    public static void main(String[] args) {
        // The constructor already loads the data.
        GBDT dGbdt = new GBDT();
        System.err.println(dGbdt.n);
    }
}

  

上一篇:修改zend studio字体大小


下一篇:Python 迭代器和生成器(转)