Solution
对每个节点维护一个权值线段树,线段树上维护区间内的最大计数以及取得该最大计数的类型编号(计数相同时取编号较小者)。利用树上差分的思想,在 \(x\)、\(y\) 位置 \(+1\),在 \(LCA(x, y)\) 位置 \(-1\),在 \(fa_{LCA(x, y)}\) 位置 \(-1\)。
所有差分标记打完后,自底向上(后序遍历)不断把线段树向上合并。简单来说,节点 \(x\) 的线段树 = \(x\) 自身的差分线段树与其所有儿子子树的线段树合并起来的结果,此时每个类型的计数就是该类型在 \(x\) 上的实际数量。
注意!必须在节点 \(x\) 合并完所有儿子之后、被合并进父节点之前,立即记录 \(x\) 的答案。因为线段树合并会复用并原地修改儿子的线段树节点,如果等到最后再统一查询,\(ans\) 可能已经被改变。
Code
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
// N: capacity shared by the edge list, the per-vertex arrays and the
//    segment-tree node pool (L/R/ans/maxx).
// Z: upper end of the value range [1, Z] indexed by every segment tree.
// NOTE(review): each operation in main performs up to 4 updates, and each update
// may allocate one node per tree level (~18 levels for Z = 100023). For large m
// the pool could in principle exceed N; confirm against the problem's limits.
const int N = 6000005;
const int Z = 100023;

// Adjacency-list edge; member order matters for the aggregate init in add().
struct edge {
    int nxt;  // index of the next edge leaving the same vertex (0 ends the list)
    int to;   // endpoint vertex
};
edge e[N];

int n, m;
int tot = 0;   // edges stored so far in e[] (1-based, so 0 means "no edge")
int cnt = 0;   // segment-tree nodes allocated so far (node 0 is the null node)
int last = 0;  // not referenced by any code visible in this file

int head[N];         // head[x]: first edge out of vertex x
int dep[N];          // depth of each vertex; dep[0] = 0 for the virtual root parent
int t[N];            // t[x]: root of vertex x's segment tree (0 = empty)
int L[N];            // left child of a segment-tree node
int R[N];            // right child of a segment-tree node
int ans[N];          // value at which maxx is attained in the node's range (ties -> left)
int maxx[N];         // maximum count inside the node's range
int ANS[N];          // answer printed for each vertex
int f[N / 10][32];   // f[x][0] = x; f[x][i] (i >= 1) = 2^(i-1)-th ancestor of x
// Pushes the directed edge x -> y onto the front of x's adjacency list.
void add(int x, int y)
{
    ++tot;
    e[tot].nxt = head[x];
    e[tot].to = y;
    head[x] = tot;
}
void dfs(int x, int Fa)
{
dep[x] = dep[Fa] + 1, f[x][0] = x, f[x][1] = Fa;
for(int i = 2; i <= 30; i++) f[x][i] = f[f[x][i - 1]][i - 1];
for(int i = head[x]; i; i = e[i].nxt) if(e[i].to != Fa) dfs(e[i].to, x);
}
// Lowest common ancestor of a and b using the lifting table built by dfs().
// f[v][i] (i >= 1) jumps 2^(i-1) levels up; f[v][0] == v is a no-op jump,
// so the loops below stop at i = 1.
int LCA(int a, int b)
{
    // Let b be the deeper (or equally deep) endpoint.
    if (dep[b] < dep[a]) std::swap(a, b);

    // Raise b to a's depth.
    for (int i = 30; i >= 1; --i)
    {
        if (dep[f[b][i]] >= dep[a]) b = f[b][i];
    }
    if (a == b) return a;

    // Raise both while their ancestors differ; they end as children of the LCA.
    for (int i = 30; i >= 1; --i)
    {
        if (f[a][i] == f[b][i]) continue;
        a = f[a][i];
        b = f[b][i];
    }
    return f[a][1];
}
// Pulls node x's maximum and its value from the better child. Ties favour the
// left child (smaller values). A missing child is node 0, whose maxx and ans are 0.
void push_up(int x)
{
    const int src = (maxx[R[x]] > maxx[L[x]]) ? R[x] : L[x];
    maxx[x] = maxx[src];
    ans[x] = ans[src];
}
// Adds k to the count of value x in the tree rooted at pre, which covers [l, r].
// Nodes are allocated from the pool on demand (pre == 0 means "no node yet").
// Returns the root of the updated subtree.
int update(int pre, int l, int r, int x, int k)
{
    const int node = pre ? pre : ++cnt;
    if (l == r)
    {
        maxx[node] += k;
        ans[node] = l;
        return node;
    }
    const int mid = (l + r) >> 1;
    if (x > mid)
        R[node] = update(R[node], mid + 1, r, x, k);
    else
        L[node] = update(L[node], l, mid, x, k);
    push_up(node);
    return node;
}
// Merges the tree rooted at v into the tree rooted at u over [l, r], summing the
// counts of equal values. The merge is destructive: u's nodes are modified in
// place, and subtrees of v may be adopted into the result. v's tree therefore
// shares nodes with the returned root and can change with later updates to it.
int merge(int u, int v, int l, int r)
{
    if (!u || !v) return u ? u : v;
    if (l != r)
    {
        const int mid = (l + r) >> 1;
        L[u] = merge(L[u], L[v], l, mid);
        R[u] = merge(R[u], R[v], mid + 1, r);
        push_up(u);
        return u;
    }
    // Leaf: accumulate v's count into u.
    maxx[u] += maxx[v];
    ans[u] = l;
    return u;
}
void solve(int x)
{
for(int i = head[x]; i; i = e[i].nxt)
{
int v = e[i].to;
if(v != f[x][1]) solve(v), t[x] = merge(t[x], t[v], 1, Z);
}
if(maxx[t[x]]) ANS[x] = ans[t[x]];
}
// Reads n, m, the n-1 undirected tree edges, and m operations "x y z". Each
// operation is applied with a tree difference (+1 at x and y, -1 at LCA(x, y)
// and at its parent). It then prints, per vertex, the value with the highest
// count (smallest value on ties, 0 when no value has a positive count).
int main()
{
    // No memset is needed: every array is a global with static storage duration
    // and is therefore zero-initialised. Clearing f, t, head and maxx explicitly
    // would touch ~149 MB of otherwise-untouched pages (f alone is ~77 MB).
    if (scanf("%d%d", &n, &m) != 2) return 1;
    for (int i = 1; i < n; i++)
    {
        int x, y;
        // Stop on malformed input instead of linking garbage vertex ids.
        if (scanf("%d%d", &x, &y) != 2) return 1;
        add(x, y);
        add(y, x);
    }
    dep[0] = 0;  // virtual parent of the root sits at depth 0
    dfs(1, 0);
    for (int i = 1; i <= m; i++)
    {
        int x, y, z;
        if (scanf("%d%d%d", &x, &y, &z) != 3) return 1;
        const int l = LCA(x, y);
        t[x] = update(t[x], 1, Z, z, 1);
        t[y] = update(t[y], 1, Z, z, 1);
        t[l] = update(t[l], 1, Z, z, -1);
        const int above = f[l][1];
        if (above) t[above] = update(t[above], 1, Z, z, -1);  // skip when l is the root
    }
    solve(1);
    for (int i = 1; i <= n; i++) printf("%d\n", ANS[i]);
    return 0;
}