题目链接
题意:
树上每个结点有一个颜色(用整数编号表示)。对每个结点,求其子树中出现次数最多的颜色的编号之和;若有多种颜色并列出现次数最多,则把它们的编号全部累加。
思路:
树上启发式合并(DSU on tree):对每个结点,先处理轻儿子并清空其统计信息,再处理重儿子并保留其统计信息,最后只把轻儿子子树重新加入统计。这样可以把暴力做法的 \(O(n^2)\) 复杂度降为 \(O(n\log n)\),复杂度证明可参考 OI wiki。
code:
#include <iostream>
#include <cstdio>
#include <string>
#include <cstring>
#include <algorithm>
#include <queue>
#include <vector>
#include <deque>
#include <cmath>
#include <ctime>
#include <map>
#include <set>
// #include <unordered_map>
// ---- Competitive-programming shorthands -----------------------------------
#define fi first
#define se second
#define pb push_back
#define endl "\n"   // plain newline instead of std::endl: no flush per line
#define debug(x) cout << #x << ":" << (x) << endl;
#define bug cout << "********" << endl;
#define all(x) (x).begin(), (x).end()
// Lowest set bit of x. Fully parenthesised: the former `x & -x` expansion
// mis-parsed inside larger expressions, e.g. `lowbit(8) == 8` became
// `8 & (-8 == 8)` because == binds tighter than &.
#define lowbit(x) ((x) & -(x))
#define fin(x) freopen(x, "r", stdin)
#define fout(x) freopen(x, "w", stdout)
#define ull unsigned long long
#define ll long long
const double eps = 1e-15;
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
const double pi = acos(-1.0);
const int mod = 1e9 + 7;
// Capacity of the global per-node / per-colour arrays; must exceed both the
// node count n and the largest colour value.
const int maxn = 2e5 + 10;
using namespace std;
// Euler-tour bookkeeping, filled by dfs1: the subtree of u occupies DFS
// positions L[u]..R[u], num[p] is the node visited at position p, and dfn is
// the running position counter.
int L[maxn], R[maxn];
int num[maxn];
int dfn;
// Tree shape: n nodes, sz[u] = subtree size, bson[u] = heavy child (0 = none).
int n;
int sz[maxn];
int bson[maxn];
// col[u] = colour of node u; cnt[c] = occurrences of colour c in the set of
// nodes currently added.
int col[maxn];
int cnt[maxn];
// For k >= 1: sum[k] = number of colours occurring exactly k times and
// summ[k] = sum of those colour values; maxx = current highest count.
int sum[maxn];
int maxx;
ll summ[maxn];
// ans[u] = answer for the subtree rooted at u.
ll ans[maxn];
// Adjacency lists of the undirected tree.
vector<int> v[maxn];
// Adds node u's colour to the running statistics: colour c moves from the
// "occurs k times" bucket to the "occurs k+1 times" bucket, and the maximum
// occurrence count is raised if needed.
void add(int u){
    const int c = col[u];
    const int before = cnt[c];   // occurrences of c before this insertion

    sum[before] -= 1;
    summ[before] -= c;
    sum[before + 1] += 1;
    summ[before + 1] += c;

    cnt[c] = before + 1;
    if (cnt[c] > maxx) maxx = cnt[c];
}
// Removes node u's colour from the running statistics (the inverse of add):
// colour c moves from the "occurs k times" bucket to "occurs k-1 times".
void del(int u){
    const int c = col[u];
    const int before = cnt[c];   // occurrences of c before this removal

    sum[before] -= 1;
    summ[before] -= c;
    sum[before - 1] += 1;
    summ[before - 1] += c;

    // If c was the last colour sitting at the maximum count, the maximum
    // drops by exactly one (c itself now occurs maxx - 1 times).
    if (before == maxx && sum[before] == 0) --maxx;
    cnt[c] = before - 1;
}
// First pass: assigns DFS positions (L, R, num), computes subtree sizes and
// picks for each node the child with the largest subtree as its heavy child
// (the first such child on ties).
void dfs1(int u, int fa){
    L[u] = ++dfn;
    num[dfn] = u;
    sz[u] = 1;
    for (size_t k = 0; k < v[u].size(); ++k) {
        const int child = v[u][k];
        if (child == fa) continue;
        dfs1(child, u);
        sz[u] += sz[child];
        if (bson[u] == 0 || sz[child] > sz[bson[u]]) bson[u] = child;
    }
    R[u] = dfn;
}
// Second pass (DSU on tree). Computes ans[u] for every node of u's subtree.
// keep == flag: when true, the statistics of u's subtree are left in place
// for the caller; when false, they are removed before returning.
void dfs2(int u, int fa, bool flag){
    const int heavy = bson[u];

    // 1) Solve each light child independently; each clears its own stats.
    for (int child : v[u]) {
        if (child == fa || child == heavy) continue;
        dfs2(child, u, false);
    }

    // 2) Solve the heavy child last and keep its statistics.
    if (heavy != 0) dfs2(heavy, u, true);

    // 3) Re-insert every node of each light subtree, then u itself.
    for (int child : v[u]) {
        if (child == fa || child == heavy) continue;
        for (int pos = L[child]; pos <= R[child]; ++pos) add(num[pos]);
    }
    add(u);

    ans[u] = summ[maxx];

    // 4) Discard the whole subtree's statistics unless the caller keeps them.
    if (flag) return;
    for (int pos = L[u]; pos <= R[u]; ++pos) del(num[pos]);
}
int main(){
scanf("%d", &n);
for(int i = 1; i <= n; i ++)scanf("%d", &col[i]);
for(int i = 1, a, b; i < n; i ++){
scanf("%d%d", &a, &b);
v[a].pb(b), v[b].pb(a);
}
dfs1(1, 0);
dfs2(1, 0, false);
for(int i = 1; i <= n; i ++)printf("%lld ", ans[i]);
return 0;
}