With $Dsu \ on \ tree$ we can answer queries of this type:
How many vertices in the subtree of vertex $v$ have some property? All of the queries together can be answered in $O (n \log n)$ time.
这题写的是轻重儿子(重链剖分)版本的 $Dsu \ on \ tree$
具体流程如下:
每次先递归计算轻儿子,再单独递归重儿子,计算完后轻儿子的一些信息需要删掉,但是重儿子的信息无需删除,如此出解,相当于是优化了暴力的多余部分
每个节点只会在它到根的路径上每经过一条轻边时被重新统计一次,而由重链剖分的性质,任一节点到根的路径上至多有 $O (\log n)$ 条轻边,故总复杂度为 $O (n \log n)$
代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype> // isdigit, used by getnum

using namespace std;

typedef long long LL;

// Capacity bounds for vertices, edges and colours. Vertices and edges are
// indexed from 1, so each array needs a few slots beyond 100000.
const int MAXN = 100000 + 10;
const int MAXM = 100000 + 10;
const int MAXC = 100000 + 10;

// One directed edge of the adjacency structure: `to` is the endpoint and
// `next` is the index of the next edge leaving the same vertex (0 = none).
struct LinkedForwardStar {
    int to;
    int next;
};

// Every undirected edge is stored as two directed edges, hence twice MAXM.
LinkedForwardStar Link[MAXM << 1];
// Head[u] is the index of the most recently inserted edge leaving u (0 = none).
int Head[MAXN] = {};
// Number of directed edges stored so far; edge indices start at 1.
int size = 0;

// Adds the directed edge u -> v at the front of u's edge list.
void Insert (int u, int v) {
    // `::size` is spelled out because `using namespace std;` makes an
    // unqualified `size` ambiguous with std::size since C++17.
    const int id = ++::size;
    Link[id].to = v;
    Link[id].next = Head[u];
    Head[u] = id;
}

// Number of vertices in the tree.
int N;
// colour[v]: colour of vertex v.
int colour[MAXN];

// son[v]: heavy child of v (the child with the largest subtree), or -1 if v
// has no children.
int son[MAXN] = {};
// subsize[v]: number of vertices in the subtree rooted at v, v included.
int subsize[MAXN] = {};

// Computes subsize[] and son[] for every vertex in the subtree of `root`.
// `father` is the parent of `root`; the edge back to it is skipped.
void DFS (int root, int father) {
    son[root] = -1;
    subsize[root] = 1;
    for (int i = Head[root]; i; i = Link[i].next) {
        int v = Link[i].to;
        if (v == father)
            continue;
        DFS (v, root);
        subsize[root] += subsize[v];
        // Strict comparison: on ties the first child visited stays heavy.
        if (son[root] == -1 || subsize[v] > subsize[son[root]])
            son[root] = v;
    }
}
// vis[v] != 0 makes calc skip v and its whole subtree.
int vis[MAXN] = {};
// total[c]: how many currently counted vertices have colour c.
int total[MAXC] = {};
// maxv: largest value in total[]; sum: sum of all colours that reach maxv.
int maxv = 0;
LL sum = 0;

// Walks the subtree of `root` (skipping subtrees whose root is marked in vis)
// and adds `delta` to the counter of every visited vertex's colour.
// While adding (delta > 0) it keeps maxv and sum up to date; while removing
// (delta < 0) they are left alone and the caller resets them.
void calc (int root, int father, int delta) {
    total[colour[root]] += delta;
    if (delta > 0 && total[colour[root]] >= maxv) {
        if (total[colour[root]] > maxv) {
            // A new strict maximum: restart the sum of dominating colours.
            sum = 0;
            maxv = total[colour[root]];
        }
        sum += colour[root];
    }
    for (int i = Head[root]; i; i = Link[i].next) {
        int v = Link[i].to;
        if (v == father || vis[v])
            continue;
        calc (v, root, delta);
    }
}
// answer[v]: sum of the most frequent colours in the subtree of v.
LL answer[MAXN] = {};

// Computes answer[] for every vertex in the subtree of `root`.
// `type` is non-zero when `root` is the heavy child of `father`; in that case
// its colour counts are left in total[] for the parent to reuse.
void Solve (int root, int father, int type) {
    // Light children first: each one wipes its own contribution on return.
    for (int i = Head[root]; i; i = Link[i].next) {
        int v = Link[i].to;
        if (v == father || v == son[root])
            continue;
        Solve (v, root, 0);
    }
    // Heavy child last: its counts stay in total[], so mark it to keep calc
    // from counting it a second time.
    if (~ son[root]) {
        Solve (son[root], root, 1);
        vis[son[root]] = 1;
    }
    // Add root itself and all light subtrees on top of the heavy child's data.
    calc (root, father, 1);
    answer[root] = sum;
    if (~ son[root])
        vis[son[root]] = 0;
    // A light child must remove everything it added, heavy subtree included.
    if (! type) {
        calc (root, father, -1);
        maxv = sum = 0;
    }
}

// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters in front of it. Returns 0 if the input ends before a
// digit is found.
int getnum () {
    int num = 0;
    char ch = getchar ();
    // The EOF check keeps truncated input from looping forever.
    while (ch != EOF && ! isdigit (ch))
        ch = getchar ();
    while (isdigit (ch)) {
        num = (num << 3) + (num << 1) + ch - '0'; // num * 10 + digit
        ch = getchar ();
    }
    return num;
}

// Reads N, the N colours and the N - 1 edges, then prints answer[1..N]
// separated by single spaces on one line.
int main () {
    N = getnum ();
    for (int i = 1; i <= N; i ++)
        colour[i] = getnum ();
    for (int i = 1; i < N; i ++) {
        int u = getnum (), v = getnum ();
        Insert (u, v), Insert (v, u);
    }
    // Vertex 1 is the root; 0 stands for "no parent".
    DFS (1, 0), Solve (1, 0, 0);
    for (int i = 1; i <= N; i ++) {
        if (i > 1)
            putchar (' ');
        printf ("%lld", answer[i]);
    }
    puts ("");
    return 0;
}

/*
4
1 2 3 4
1 2
2 3
2 4
*/

/*
15
1 2 3 1 2 3 3 1 1 3 2 2 1 2 3
1 2
1 3
1 4
1 14
1 15
2 5
2 6
2 7
3 8
3 9
3 10
4 11
4 12
4 13
*/