题目链接:http://poj.org/problem?id=1988
题意:有n个元素,开始每个元素各自在一个栈中,有两种操作,将含有元素x的栈放在含有y的栈的顶端,合并为一个栈。 第二种操作是询问含有x元素下面有多少个元素。 思路: 并查集,把每一堆看作一个栈,堆的下方看作栈顶。因为当我们知道栈中元素的总数,和某元素到“栈顶”的距离, 我们就能知道这个元素下面有多少元素。合并操作的时候,始终使用在下面栈的根来做合并之后的根,这样也就达到了栈中的根是栈中的“栈顶”元素的效果,我们只需在每个“栈顶”中记录该栈中的元素总数即可。然而我们还需要得知某元素到“栈顶”的距离,这样我们就需要记录每个元素到其父亲的距离,把它到根上所经过的距离加起来,即为它到“栈顶”的距离。这样我们就得出了结果。 这个图是在合并的时候的关键部分 另外,在进行Find()操作时,有一句 under[x] += under[t[x].parent];这句话就是在递归寻找根结点时,计算出每个元素距离栈底(根)的距离。1 #include<iostream> 2 #include<stdio.h> 3 #include<cstring> 4 #include<cmath> 5 #include<vector> 6 #include<stack> 7 #include<map> 8 #include<set> 9 #include<list> 10 #include<queue> 11 #include<string> 12 #include<algorithm> 13 #include<iomanip> 14 using namespace std; 15 16 struct node 17 { 18 int parent; 19 int date; 20 }; 21 22 int * total; 23 int * under; 24 25 class DisJoinSet 26 { 27 protected: 28 int n; 29 30 node * tree; 31 public: 32 DisJoinSet(int n); 33 ~DisJoinSet(); 34 void Init(); 35 int Find(int x); 36 void Union(int x,int y); 37 }; 38 39 DisJoinSet::DisJoinSet(int n) 40 { 41 this->n = n; 42 tree = new node[n+2]; 43 total = new int[n+2]; 44 under = new int[n+2]; 45 Init(); 46 } 47 DisJoinSet::~DisJoinSet() 48 { 49 delete[] under; 50 delete[] total; 51 delete[] tree; 52 } 53 54 void DisJoinSet::Init() 55 { 56 for(int i = 1;i <= n ;i ++) 57 { 58 tree[i].date = i; 59 tree[i].parent = i; 60 total[i] = 1; 61 under[i] = 0; 62 } 63 } 64 int DisJoinSet::Find(int x) 65 { 66 //int temp = tree[x].parent; 67 if(x != tree[x].parent) 68 { 69 int par = Find(tree[x].parent); 70 under[x] += under[tree[x].parent];//把父亲结点下面的个数加到自己头上 71 tree[x].parent = par; 72 return tree[x].parent; 73 } 74 else 75 { 76 return x; 77 } 78 } 79 80 void DisJoinSet::Union(int x,int y) 81 { 82 int pa = Find(x); 83 int pb = Find(y); 84 if(pa == pb)return ; 85 else 86 { 87 tree[pa].parent = pb;//x的根变为y的根 即把x所在的堆放在y所在的堆上面 88 under[pa] = total[pb];//pa下的数量即原来y所在栈里的元素total 89 total[pb] += total[pa];//更新y的totoal 90 } 91 } 92 93 int main() 94 { 95 int p; 96 while(scanf("%d",&p) != EOF) 97 { 98 if(p == 0)break; 99 DisJoinSet dis(p); 100 char s1[2]; 101 for(int i = 0 ;i < p ;i++) 102 { 103 104 int s2; 105 int s3; 106 scanf("%s",s1); 107 if(s1[0] == 'M') 108 { 109 scanf("%d%d",&s2,&s3); 110 int pa = dis.Find(s2); 111 int pb = dis.Find(s3); 112 if(pa != pb) 113 { 114 dis.Union(s2,s3); 115 } 116 } 117 if(s1[0] == 'C') 118 { 119 scanf("%d",&s2); 120 dis.Find(s2); 121 cout<<under[s2]<<endl; 122 } 123 } 124 dis.~DisJoinSet(); 125 } 126 return 0; 127 }View Code