P2664 树上游戏

给定一棵大小为 \(n\) 的树,每个结点都有颜色。
定义 \(s(i, j)\) 为从 \(i\) 到 \(j\) 的路径上不同颜色的数量,并令 \(sum_i = \sum\limits_{j=1}^{n} s(i,j)\)。
求出所有的 \(sum_i\)。

较为复杂的点分治题。
也可以用差分 \(O(n)\) 解决。

#include <cstdio>
#include <algorithm>

// Reads the next non-negative decimal integer from standard input.
// Every character before the first digit is skipped. If the input ends before
// any digit is found, 0 is returned instead of spinning forever.
inline int read(void){
	// getchar() returns an int; keeping it as an int lets EOF stay
	// distinguishable from every real byte, whatever the signedness of char.
	int ch = std::getchar();
	while(ch != EOF && (ch < '0' || ch > '9'))
		ch = std::getchar();
	int res = 0;
	while(ch >= '0' && ch <= '9'){
		res = res * 10 + (ch - '0');
		ch = std::getchar();
	}
	return res;
}

// 64-bit integer used for the per-vertex answers and the running sums.
typedef long long ll;

// Capacity of every per-vertex and per-colour array: 1e5 plus a little slack.
const int MAXN = 1e5 + 19;

// One directed arc of the adjacency list: `to` is the endpoint and `next` is
// the index of the next arc leaving the same vertex (0 ends the chain).
struct Edge{
	int to;
	int next;
};

// Arc pool; sized at twice MAXN because add() is called once per direction
// of every tree edge.
Edge edge[MAXN << 1];

// head[v] is the index of the first arc leaving v (0 when v has none);
// cnt is the number of arcs allocated so far.
int head[MAXN], cnt;

// Allocates a new arc from -> to and puts it at the front of from's chain.
inline void add(int from, int to){
	const int id = ++cnt;
	edge[id].next = head[from];
	edge[id].to = to;
	head[from] = id;
}

// vist[v] is set by solve() once v has been picked as a centroid; every
// traversal in this file refuses to step onto such a vertex.
bool vist[MAXN];

// Centroid search inside the not-yet-removed component containing a vertex.
namespace gravity{
	// n        - vertex count of the component being examined
	// size[v]  - subtree size of v in the traversal rooted at the entry vertex
	// min_size - smallest "largest remaining piece" found so far
	// root     - vertex achieving min_size
	int n, size[MAXN], min_size, root;

	// Fills size[] for the subtree of `node`, whose parent is `f`.
	void dfs0(int node, int f){
		size[node] = 1;
		for(int e = head[node]; e != 0; e = edge[e].next){
			const int child = edge[e].to;
			if(child == f || vist[child])
				continue;
			dfs0(child, node);
			size[node] += size[child];
		}
	}

	// Scores every vertex by the largest piece left after deleting it and
	// remembers the best one in `root`.
	void dfs1(int node, int f){
		// Piece on the parent's side of `node`.
		int largest = n - size[node];
		for(int e = head[node]; e != 0; e = edge[e].next){
			const int child = edge[e].to;
			if(child == f || vist[child])
				continue;
			dfs1(child, node);
			if(size[child] > largest)
				largest = size[child];
		}
		if(largest < min_size){
			min_size = largest;
			root = node;
		}
	}

	// Returns the centroid of the component that contains `node`.
	int getroot(int node){
		dfs0(node, -1);
		n = size[node];
		min_size = 0x3f3f3f3f;
		dfs1(node, -1);
		return root;
	}
}

// Only getroot is imported; gravity's own n and size stay distinct from the
// globals of the same name declared below.
using gravity::getroot;

// n: number of vertices; c[v]: colour of vertex v (both read in main).
int n, c[MAXN];

// ans[v]: answer accumulated for vertex v, printed by main.
ll ans[MAXN];
// dp[col]: per-colour contribution maintained by dfs1/dfs2/dfs3;
// sum: always updated together with dp, so it equals the total of dp.
ll sum, dp[MAXN];
// counted[col]: colour col is currently on the path from the centroid during
// dfs1/dfs2 (or is the centroid's own colour while solve is running).
bool counted[MAXN];

// size[v]: subtree size of v with the current centroid as root (see dfs0).
int size[MAXN];
// Colours pushed by dfs0 so that solve can reset exactly the touched dp cells.
int stack[MAXN], top;

// Computes size[] for the subtree of `node` (parent `f`) inside the current
// component and pushes each visited vertex's colour onto `stack`.
void dfs0(int node, int f){
	stack[++top] = c[node];
	size[node] = 1;
	for(int e = head[node]; e != 0; e = edge[e].next){
		const int child = edge[e].to;
		if(child == f || vist[child])
			continue;
		dfs0(child, node);
		size[node] += size[child];
	}
}

// Adds the subtree of `node` (parent `f`) to dp/sum: whenever a colour appears
// for the first time on the path from the centroid, the whole subtree below
// that vertex is credited to the colour.
void dfs1(int node, int f){
	const int colour = c[node];
	const bool first = !counted[colour];
	if(first){
		counted[colour] = true;
		dp[colour] += size[node];
		sum += size[node];
	}
	for(int e = head[node]; e != 0; e = edge[e].next){
		const int child = edge[e].to;
		if(child == f || vist[child])
			continue;
		dfs1(child, node);
	}
	// Unmark on the way out so sibling branches see the colour as new again.
	if(first)
		counted[colour] = false;
}

// Exact inverse of dfs1: withdraws the contribution of the subtree of `node`
// (parent `f`) from dp/sum.
void dfs2(int node, int f){
	const int colour = c[node];
	const bool first = !counted[colour];
	if(first){
		counted[colour] = true;
		dp[colour] -= size[node];
		sum -= size[node];
	}
	for(int e = head[node]; e != 0; e = edge[e].next){
		const int child = edge[e].to;
		if(child == f || vist[child])
			continue;
		dfs2(child, node);
	}
	if(first)
		counted[colour] = false;
}

// t: number of colours currently marked in bra[] by the active dfs3 calls;
// k: set by solve to size[centroid] - size[child subtree] before each dfs3.
ll t, k;

// bra[col]: colour col has been marked by a dfs3 call that is still on the
// recursion stack.
bool bra[MAXN];

void dfs3(int node, int f){
	sum -= dp[c[node]];
	ll tmp = dp[c[node]];
	dp[c[node]] = 0;
	bool flag = false;
	if(!bra[c[node]])
		bra[c[node]] = true,
		flag = true,
		++t;
	ans[node] += k * t + sum;
	for(int i = head[node]; i; i = edge[i].next)
		if(edge[i].to != f && !vist[edge[i].to])
			dfs3(edge[i].to, node);
	dp[c[node]] = tmp;
	sum += dp[c[node]];
	if(flag)
		bra[c[node]] = false,
		--t;
}

// Processes the component containing `node`: picks its centroid, adds to ans[]
// the contribution of every path that starts in one child subtree of the
// centroid and ends outside it (plus the centroid's own total), then recurses
// into each remaining piece.
void solve(int node){
	// Work from the centroid and remove it from all later traversals.
	node = getroot(node);
	vist[node] = 1;
	// Subtree sizes relative to the centroid; also records colours for the reset below.
	dfs0(node, -1);
	// The centroid's colour lies on every path from the centroid, so it is
	// credited once for the whole component and kept marked so dfs1/dfs2 skip it.
	counted[c[node]] = true;
	sum += size[node];
	dp[c[node]] += size[node];
	// Accumulate dp/sum over every child subtree.
	for(int i = head[node]; i; i = edge[i].next)
		if(!vist[edge[i].to])
			dfs1(edge[i].to, node);
	// sum now covers all paths starting at the centroid.
	ans[node] += sum;
	for(int i = head[node]; i; i = edge[i].next)
		if(!vist[edge[i].to]){
			// k: number of vertices of the component outside this child subtree.
			k = size[node] - size[edge[i].to];
			// Withdraw this subtree's share of the centroid colour ...
			sum -= size[edge[i].to],
			dp[c[node]] -= size[edge[i].to];
			// ... and of every other colour, answer for the subtree, then put it back.
			dfs2(edge[i].to, node);
			dfs3(edge[i].to, node);
			dfs1(edge[i].to, node);
			sum += size[edge[i].to],
			dp[c[node]] += size[edge[i].to]; 
		}
	// Reset the shared state; only dp cells for colours recorded by dfs0 are cleared.
	counted[c[node]] = false;
	sum = 0ll;
	while(top)
		dp[stack[top--]] = 0;
	// Recurse into the pieces left after removing the centroid.
	for(int i = head[node]; i; i = edge[i].next)
		if(!vist[edge[i].to])
			solve(edge[i].to);
}

// Reads the tree (vertex count, colours, then n - 1 edges), runs the
// decomposition from vertex 1 and prints one answer per line.
int main(){
	n = read();
	for(int v = 1; v <= n; ++v)
		c[v] = read();
	for(int e = 1; e < n; ++e){
		const int u = read();
		const int v = read();
		add(u, v);
		add(v, u);
	}
	solve(1);
	for(int v = 1; v <= n; ++v)
		std::printf("%lld\n", ans[v]);
	return 0;
}
上一篇:P2664 树上游戏


下一篇:vim操作