给定一棵包含 \(n\) 个结点的树,每个结点都染有一种颜色。
定义 \(s(i, j)\) 为从 \(i\) 到 \(j\) 的路径上不同颜色的数量,并定义 \(sum_i = \sum\limits_{j=1}^{n} s(i, j)\)。
对每个 \(i \in [1, n]\),求出 \(sum_i\)。
这是一道较为复杂的点分治题:对每个分治重心,统计所有经过它的路径对 \(sum\) 的贡献,时间复杂度为 \(O(n \log n)\)。
本题也可以用树上差分在 \(O(n)\) 时间内解决。
#include <cstdio>
#include <algorithm>
// Reads the next non-negative decimal integer from stdin, skipping every
// non-digit character in front of it (a leading '-' is skipped, not honored).
// Returns 0 if end of input is reached before any digit is found, instead of
// spinning forever waiting for a digit that will never arrive.
inline int read(void){
    int res = 0;
    // Keep the result of getchar() as int so EOF stays distinguishable from
    // every valid byte (with a plain/unsigned char it would never compare equal).
    int ch = std::getchar();
    while(ch != EOF && (ch < '0' || ch > '9'))
        ch = std::getchar();
    while(ch >= '0' && ch <= '9')
        res = res * 10 + (ch - '0'), ch = std::getchar();
    return res;
}
// 64-bit accumulator type: per-vertex answers are sums over up to n paths.
typedef long long ll;
// Capacity of every per-vertex / per-color array: 1e5 plus a little slack.
// NOTE(review): presumably n and all colors are <= 1e5 — confirm against the limits.
const int MAXN = 100000 + 19;
// Forward-star (linked adjacency list) edge storage. Edges are numbered from 1;
// head[v] is the newest edge leaving v (0 means none) and edge[e].next links to
// the previous edge leaving the same vertex. main() stores every undirected
// edge twice, hence the MAXN << 1 slots.
struct Edge{
    int to, next;
}edge[MAXN << 1];
int head[MAXN], cnt;
// Prepends the directed edge from -> to to the front of from's edge list.
inline void add(int from, int to){
    Edge &slot = edge[++cnt];
    slot.to = to;
    slot.next = head[from];
    head[from] = cnt;
}
// vist[v] becomes true once v has been chosen as a centroid; every traversal
// in this file treats such vertices as removed from the tree.
bool vist[MAXN];
// Centroid search over the component (vertices not in vist[]) that contains
// a given vertex.
namespace gravity{
    int n, size[MAXN], min_size, root;

    // Fills size[] with subtree sizes of the component rooted at node
    // (f is the parent, -1 for the root).
    void dfs0(int node, int f){
        int total = 1;
        for(int e = head[node]; e; e = edge[e].next){
            const int v = edge[e].to;
            if(v == f || vist[v])
                continue;
            dfs0(v, node);
            total += size[v];
        }
        size[node] = total;
    }

    // Post-order walk: the largest piece left by deleting node is either the
    // part above it (n - size[node]) or one child subtree. Keeps the first
    // vertex (in post-order) minimizing that value in root / min_size.
    void dfs1(int node, int f){
        int heaviest = n - size[node];
        for(int e = head[node]; e; e = edge[e].next){
            const int v = edge[e].to;
            if(v == f || vist[v])
                continue;
            dfs1(v, node);
            if(size[v] > heaviest)
                heaviest = size[v];
        }
        if(heaviest < min_size){
            min_size = heaviest;
            root = node;
        }
    }

    // Returns a centroid of the component containing node.
    int getroot(int node){
        dfs0(node, -1);
        min_size = 0x3f3f3f3f;
        n = size[node];
        dfs1(node, -1);
        return root;
    }
}
using gravity::getroot;
// Vertex count and per-vertex colors (vertices are 1-indexed).
int n;
int c[MAXN];
// ans[v] accumulates sum_v over all centroid levels.
ll ans[MAXN];
// Running totals for the current centroid: sum is the grand total and
// dp[color] the part of it contributed by that color.
ll sum;
ll dp[MAXN];
// Colors already present on the path from the current centroid.
bool counted[MAXN];
// Subtree sizes with the current centroid as root.
int size[MAXN];
// Colors of the vertices in the current component, used to reset dp[].
int stack[MAXN];
int top;
// Roots the current component at `node` (parent f, -1 for the root) and
// fills size[] for each of its vertices. It also records each visited vertex's
// color on stack[], so solve() can later reset dp[] for exactly those colors.
void dfs0(int node, int f){
    stack[++top] = c[node];
    int total = 1;
    for(int e = head[node]; e != 0; e = edge[e].next){
        const int v = edge[e].to;
        if(v == f || vist[v])
            continue;
        dfs0(v, node);
        total += size[v];
    }
    size[node] = total;
}
void dfs1(int node, int f){
bool flag = false;
if(!counted[c[node]]){
sum += size[node];
dp[c[node]] += size[node];
counted[c[node]] = true;
flag = true;
}
for(int i = head[node]; i; i = edge[i].next)
if(edge[i].to != f && !vist[edge[i].to])
dfs1(edge[i].to, node);
if(flag)
counted[c[node]] = false;
}
void dfs2(int node, int f){
bool flag = false;
if(!counted[c[node]]){
sum -= size[node];
dp[c[node]] -= size[node];
counted[c[node]] = true;
flag = true;
}
for(int i = head[node]; i; i = edge[i].next)
if(edge[i].to != f && !vist[edge[i].to])
dfs2(edge[i].to, node);
if(flag)
counted[c[node]] = false;
}
ll t, k;
bool bra[MAXN];
void dfs3(int node, int f){
sum -= dp[c[node]];
ll tmp = dp[c[node]];
dp[c[node]] = 0;
bool flag = false;
if(!bra[c[node]])
bra[c[node]] = true,
flag = true,
++t;
ans[node] += k * t + sum;
for(int i = head[node]; i; i = edge[i].next)
if(edge[i].to != f && !vist[edge[i].to])
dfs3(edge[i].to, node);
dp[c[node]] = tmp;
sum += dp[c[node]];
if(flag)
bra[c[node]] = false,
--t;
}
// Centroid decomposition driver. For the component containing `node` it finds
// the centroid and adds to ans[] the contribution of every path through that
// centroid. It then recurses into the pieces left after removing the centroid.
void solve(int node){
node = getroot(node);
vist[node] = 1;
// Root the component at the centroid: fills size[] and pushes every vertex's
// color onto stack[].
dfs0(node, -1);
// The centroid's color is the first color on every path starting at it and
// contributes size[node], i.e. the whole component including the centroid.
counted[c[node]] = true;
sum += size[node];
dp[c[node]] += size[node];
// Add the first-occurrence contributions of every child subtree. Afterwards
// `sum` is the centroid's total over the whole component.
for(int i = head[node]; i; i = edge[i].next)
if(!vist[edge[i].to])
dfs1(edge[i].to, node);
ans[node] += sum;
for(int i = head[node]; i; i = edge[i].next)
if(!vist[edge[i].to]){
// k = number of component vertices outside this child's subtree.
k = size[node] - size[edge[i].to];
// Temporarily exclude this subtree, in four steps:
//   1. shrink the centroid color's share to k;
//   2. remove the subtree's own contributions (dfs2);
//   3. credit every vertex of the subtree for paths leaving it through the
//      centroid (dfs3);
//   4. restore everything (dfs1 plus the additions below).
sum -= size[edge[i].to],
dp[c[node]] -= size[edge[i].to];
dfs2(edge[i].to, node);
dfs3(edge[i].to, node);
dfs1(edge[i].to, node);
sum += size[edge[i].to],
dp[c[node]] += size[edge[i].to];
}
// Reset shared state for the next component. Only the colors recorded on
// stack[] by dfs0 can have a non-zero dp[] entry.
counted[c[node]] = false;
sum = 0ll;
while(top)
dp[stack[top--]] = 0;
// Recurse into each sub-component left after removing the centroid.
for(int i = head[node]; i; i = edge[i].next)
if(!vist[edge[i].to])
solve(edge[i].to);
}
// Reads n, the color of each vertex, then n - 1 undirected edges. Runs the
// centroid decomposition starting from vertex 1 and prints sum_i for
// i = 1..n, one per line.
int main(){
    n = read();
    for(int v = 1; v <= n; ++v)
        c[v] = read();
    for(int e = 1; e < n; ++e){
        const int a = read();
        const int b = read();
        add(a, b);
        add(b, a);
    }
    solve(1);
    for(int v = 1; v <= n; ++v)
        std::printf("%lld\n", ans[v]);
    return 0;
}