ID3算法 决策树 C++实现

人工智能课的实验。

数据结构:多叉树

这个实验我写了好久,开始的时候从数据的读入和表示入手,写到递归建树的部分时遇到了瓶颈,更新样例集和属性集的办法过于繁琐;

于是参考网上的代码后重新写,建立决策树类,把属性集、样例集作为数据成员加入类中,并设立访问数组,这样每次更新属性集、样例集时只是标记访问数组的对应元素即可,不必实际拷贝。

主函数:

 1 #include "Decision_tree.h"
 2 using namespace std;
 3 int main()
 4 {
 5     int num_attr,num_example;
 6     char filename[30];
 7     cout << "请输入训练集文件名:" << endl;
 8     cin >> filename;
 9     freopen(filename, "r", stdin);//从样例文件读入训练内容
10     cin >> num_attr >> num_example;//读入属性个数、例子个数
11     Decision_tree my_tree=Decision_tree(num_attr,num_example);
12     fclose(stdin);
13     freopen("CON", "r", stdin);//重定向标准输入到控制台
14     my_tree.display_attr();
15     cout << "决策树已建成,按深度优先遍历结果如下:" << endl;
16     my_tree.traverse();
17     do{
18         cout << "请输入测试数据,格式:属性1值 属性2值..." << endl;
19         Example test;
20         for (int i = 0; i < num_attr; i++)
21             cin >> test.values[i];
22         int result = my_tree.judge(test);
23         if (result == 1) cout << "分类结果为P" << endl;
24         else if (result == -1) cout << "分类结果为N" << endl;
25         else if (result == -2) cout << "无法根据已有样例集判断" << endl;
26         cout << "继续吗?(y/n)";
27         fflush(stdin);
28     } while (getchar() == 'y');
29 }

属性结构体

struct Attribute//属性
{
    string name;
    int count;//属性值个数
    int number;//属性的秩
    string values[MAX_VAL];
};

样例结构体

struct Example//样例
{
    string values[MAX];
    int pn;
    Example(){ pn = 0; }//默认为未分类的
};

决策树的结点

typedef struct Node//树的结点
{
    Attribute attr;
    Node* children[MAX_VAL];
    int classification[MAX_VAL];
    Node(){}
}Node;

决策树类的实现

ID3算法 决策树 C++实现ID3算法 决策树 C++实现
  1 class Decision_tree//决策树
  2 {
  3     Node *root;
  4     Example e[MAX];//样例全集
  5     Attribute a[MAX_ATTR];//属性全集
  6     int num_attr, num_example;
  7     int visited_exams[MAX];//样例集的访问情况
  8     int visited_attrs[MAX_ATTR];//属性集的访问情况
  9     Node* recursive_build_tree(int left_e[], int left_a[])//递归建树
 10     {
 11         double max = 0;
 12         int max_attr=-1;
 13         for (int i = 0;i<num_attr;i++)
 14         {//求信息增益最大的属性
 15             if (left_a[i]) continue;
 16             double temp = Gain(left_e, i);
 17             if (max<temp)
 18             {
 19                 max = temp;
 20                 max_attr = i;
 21             }
 22         }
 23         if (max_attr == -1) return NULL;//已没有可判的属性,返回空指针
 24         //cout << a[max_attr].name << endl;
 25         //以这个属性为结点,以各属性值为分支递归建树
 26         int p = 0, n = 0;
 27         Node *new_node=new Node();
 28         new_node->attr = a[max_attr];
 29         for (int i = 0; i<a[max_attr].count;i++)
 30         {//遍历这个属性的所有属性值
 31             for (int j = 0; j < num_example;j++)
 32             {//得到第i个属性值的正反例总数
 33                 if (left_e[j]) continue;
 34                 if (!e[j].values[max_attr].compare(a[max_attr].values[i]))
 35                 {//例子和属性都是循秩访问的,所以向量元素的顺序不能变
 36                     if (e[j].pn) p++;
 37                     else n++;
 38                 }
 39             }
 40             //cout << a[max_attr].values[i] << " ";
 41             //cout << p << " " << n << endl;
 42             if (p && !n)//全是正例,不再分
 43             {
 44                 //cout << "P" << endl;
 45                 new_node->classification[i] = 1;
 46                 new_node->children[i] = NULL;
 47             }
 48             else if (n && !p)//全是反例,不再分
 49             {
 50                 //cout << "N" << endl;
 51                 new_node->classification[i] = -1;
 52                 new_node->children[i] = NULL;
 53             }
 54             else if (!p && !n)//例子集已空
 55             {
 56                 //cout << "none" << endl;
 57                 new_node->classification[i] = -2;//表示未训练到这种分类,无法判断
 58                 new_node->children[i] = NULL;
 59             }
 60             else//例子集不空,且尚未能区分正反,更新访问情况,递归
 61             {
 62                 new_node->classification[i] = 0;
 63                 left_a[max_attr] = 1;//更新属性访问情况
 64                 int left_e_next[MAX];//下一轮的例子集(为便于回溯,不修改原例子集)
 65                 for (int k = 0; k < num_example; k++)
 66                     left_e_next[k] = left_e[k];
 67                 for (int j = 0; j < num_example; j++)
 68                 {
 69                     if (left_e[j]) continue;
 70                     if (!e[j].values[max_attr].compare(a[max_attr].values[i]))
 71                         left_e_next[j] = 0;//属性值匹配的例子,入选下一轮例子集
 72                     else left_e_next[j] = 1;//属性值不匹配,筛除
 73                 }
 74                 new_node->children[i] = recursive_build_tree(left_e_next, left_a);//递归
 75                 left_a[max_attr] = 0;//恢复属性访问情况
 76             }
 77             p = 0;
 78             n = 0;
 79         }
 80         return new_node;
 81     }
 82     double I(int p, int n)
 83     {
 84         double a = p / (p + (double)n);
 85         double b = n / (p + (double)n);
 86         if (a == 0 || b == 0) return 0;
 87         return -a*log(a) / log(2) - b*log(b) / log(2);
 88     }
 89     double Gain(int left_e[], int cur_attr)//计算信息增益
 90     {
 91         int sum_p=0, sum_n=0;
 92         int p[10] = { 0 }, n[10] = { 0 };
 93         for (int i = 0; i < num_example; i++)
 94         {//求样例集的p,n
 95             if (left_e[i]) continue;
 96             if (e[i].pn) sum_p++;
 97             else sum_n++;
 98         }
 99         if (!sum_p && !sum_n)
100         {
101             //cout << "no more examples!" << endl;
102             return -1;//样例集是空集
103         }
104             
105         double sum_Ipn = I(sum_p, sum_n);
106         for (int i = 0; i < a[cur_attr].count; i++)
107         {//求第i个属性值的p,n
108             for (int j = 0; j < num_example; j++)
109             {
110                 if (left_e[j]) continue;
111                 if (!e[j].values[cur_attr].compare(a[cur_attr].values[i]))
112                     if (e[j].pn) p[i]++;
113                     else n[i]++;
114             }
115         }
116         double E = 0;
117         for (int i = 0; i < a[cur_attr].count; i++)//计算属性的期望
118             E += (p[i] + n[i])*I(p[i], n[i]);
119         E /= (sum_p + sum_n);
120         //cout << a[cur_attr].name <<sum_Ipn - E << endl;
121         return sum_Ipn - E;
122     }
123     void recursive_traverse(Node *current)//DFS递归遍历
124     {
125         if (current == NULL) return;
126         cout << current->attr.name << endl;
127         for (int i = 0; i < current->attr.count; i++)
128         {
129             cout << current->attr.values[i] << " " << current->classification[i] << endl;
130             recursive_traverse(current->children[i]);
131         }
132     }
133     int recursive_judge(Example exa, Node *current)
134     {
135         for (int i = 0; i < current->attr.count; i++)
136         {
137             if (!exa.values[current->attr.number].compare(current->attr.values[i]))
138             {
139                 if (current->children[i]==NULL) return current->classification[i];
140                 else return recursive_judge(exa, current->children[i]);        
141             }        
142         }
143         return 0;
144     }
145 public:
146     Decision_tree(int num1,int num2)
147     {
148         
149         //通过读文件初始化
150         num_attr = num1;
151         num_example = num2;
152 
153         for (int i = 0; i<num_attr; i++)
154         {
155             a[i].number = i;//属性的秩
156             cin>>a[i].name;//读入属性名
157             cin>>a[i].count;//读入此属性的属性值个数
158             for (int j = 0; j<a[i].count; j++)
159             {
160                 cin>>a[i].values[j];//读入各属性值
161             }
162         }
163         
164         for (int i = 0; i<num_example; i++)
165         {
166             string temp;
167             for (int j = 0; j < num_attr; j++)
168             {
169                 cin>>e[i].values[j];
170             }
171             cin >> temp;
172             if (!temp.compare("P")) e[i].pn = 1;
173             else e[i].pn = 0;
174         }
175         //检查
176         /*for (int i = 0; i<num_attr; i++)
177         {
178             cout << a[i].name << endl;//读入属性名
179             for (int j = 0; j<a[i].count; j++)
180             {
181                 cout<<a[i].values[j]<<" ";//读入各属性值
182             }
183             cout << endl;
184         }
185         for (int i = 0; i<num_example; i++)
186         {
187             for (int j = 0; j < num_attr; j++)
188                 cout<<e[i].values[j]<<" ";
189             cout<<e[i].pn<<endl;
190             
191         }
192         */
193         memset(visited_exams, 0, sizeof(visited_exams));
194         memset(visited_attrs, 0, sizeof(visited_attrs));
195         root = recursive_build_tree(visited_exams,visited_attrs);
196     }
197     void traverse()
198     {
199         recursive_traverse(root);
200     }
201     int judge(Example exa)//判断
202     {
203         int result=recursive_judge(exa,root);
204         return result;
205     }
206     void display_attr()//显示属性
207     {
208         cout << "There are " << num_attr << " attributes, they are" << endl;
209         for (int i = 0; i < num_attr; i++)
210         {
211             cout << "[" << a[i].name << "]" << endl;
212             for (int j = 0; j < a[i].count; j++)
213                 cout << a[i].values[j] << " ";
214             cout << endl;
215         }
216     }
217 };
Decision_tree

现在这个版本的代码用了10小时完成,去检查时被研究生贬得一文不值。。。也的确,现在我们写的实验题目面向的都是规模非常小的问题,自然体会不到自己的代码在大数据面前的劣势。不过我现在确实学得太少了,很多数据结构都没有动手实现过,算法也是。对C++也只能算入了门。俗话说“磨刀不误砍柴工”,“工欲善其事,必先利其器”,先把基础知识学好,多做基本练习,学到的数据结构和算法都动手实现一遍,这样遇到实际问题也好对应到合适的数据结构和算法。

另外,参照一本书学习好的代码风格和习惯也是很重要的,因为写代码的习惯是思维习惯的反映,而我现在还处于初学者阶段,按照一种典型的流派模仿,构建起自己的思维模式后再谈其他的。

忽然觉得自己学了快两年编程还这么水实在是不能忍,都怪大一时年少不懂事没好好学基础。。。

不过,“悟已往之不谏,知来者之可追”,有了方向,一步步走下去就好,不求优于别人,但一定要“优于过去的自己”。

 

上一篇:用试探回溯法解决N皇后问题


下一篇:JQuery > 创建方法(函数)方法