解:首先有个套路:把每条边的权值设为 [两端点颜色不同](不同为 1,相同为 0),这样一条路径上的颜色段数就等于路径上的边权和加 1。边权用树剖加线段树直接维护,支持链上修改。
每次询问对关键点建虚树,并用树剖查询虚树上每条边的权值。然后在虚树上做树形 DP,用“开店”的方法(对每个点做链加、链查)统计答案。
#include <bits/stdc++.h>

/// Iterate over the adjacency list of x in the original tree (i is the edge index).
#define forson(x, i) for(int i = e[x]; i; i = edge[i].nex)

typedef long long LL;

// NOTE(review): the original literal for N was lost; 100010 assumes at most
// 1e5 vertices -- confirm against the limits of the problem being solved.
const int N = 100010;

/// Forward-star edge. `len` is only used by the virtual tree (EDGE).
struct Edge {
    int nex, v;
    LL len;
}edge[N << 1], EDGE[N];

int tp, TP;                 // edge counters: original tree / virtual tree
int e[N], top[N], fa[N], son[N], siz[N], d[N], pos[N], id[N], num, val[N], n, imp2[N];
int sum[N << 2], lc[N << 2], rc[N << 2], tag[N << 2];
int imp[N], K, stk[N], Top, RT, Time, E[N], vis[N], use[N], DEEP[N];
LL SIZ[N], ans[N], D[N];

/// Add the directed edge x -> y to the original tree.
inline void add(int x, int y) {
    tp++;
    edge[tp].v = y;
    edge[tp].nex = e[x];
    e[x] = tp;
    return;
}

/// ------------------- tree 1 -------------------------

/// First heavy-light pass: fills fa, siz, d and the heavy child son.
void DFS_1(int x, int f) {
    fa[x] = f;
    siz[x] = 1;
    d[x] = d[f] + 1;
    forson(x, i) {
        int y = edge[i].v;
        if(y == f) continue;
        DFS_1(y, x);
        siz[x] += siz[y];
        if(siz[y] > siz[son[x]]) {
            son[x] = y;
        }
    }
    return;
}

/// Second heavy-light pass: fills top (chain head), pos (DFS index) and id (inverse of pos).
void DFS_2(int x, int f) {
    top[x] = f;
    pos[x] = ++num;
    id[num] = x;
    // The heavy child is visited first so that every chain is contiguous in pos order.
    if(son[x]) DFS_2(son[x], f);
    forson(x, i) {
        int y = edge[i].v;
        if(y == fa[x] || y == son[x]) continue;
        DFS_2(y, y);
    }
    return;
}

/// ------------------ seg 1 ----------------------
/// Node o stores: lc/rc = colour of its leftmost/rightmost position,
/// sum = number of adjacent position pairs inside it with different colours,
/// tag = pending range-assignment colour, or -1 when there is none.

#define ls (o << 1)
#define rs (o << 1 | 1)

/// Recompute node o from its two children.
inline void pushup(int o) {
    lc[o] = lc[ls];
    rc[o] = rc[rs];
    sum[o] = sum[ls] + sum[rs] + (rc[ls] != lc[rs]);
    return;
}

/// Push a pending assignment of node o down to both children.
inline void pushdown(int o) {
    if(tag[o] != -1) {
        lc[ls] = rc[ls] = tag[ls] = tag[o];
        lc[rs] = rc[rs] = tag[rs] = tag[o];
        sum[ls] = sum[rs] = 0;  // a uniformly coloured range has no colour change
        tag[o] = -1;
    }
    return;
}

#undef ls
#undef rs

/// Build the segment tree over positions [l, r] from the initial colours val.
void build(int l, int r, int o) {
    if(l == r) {
        lc[o] = rc[o] = val[id[r]];
        sum[o] = 0;
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, o << 1);
    build(mid + 1, r, o << 1 | 1);
    pushup(o);
    return;
}

/// Assign colour v to every position in [L, R].
void change(int L, int R, int v, int l, int r, int o) {
    if(L <= l && r <= R) {
        lc[o] = rc[o] = tag[o] = v;
        sum[o] = 0;
        return;
    }
    int mid = (l + r) >> 1;
    pushdown(o);
    if(L <= mid) change(L, R, v, l, mid, o << 1);
    if(mid < R) change(L, R, v, mid + 1, r, o << 1 | 1);
    pushup(o);
    return;
}

/// Return the colour stored at position p.
int ask(int p, int l, int r, int o) {
    if(l == r) return lc[o];
    int mid = (l + r) >> 1;
    pushdown(o);
    if(p <= mid) return ask(p, l, mid, o << 1);
    else return ask(p, mid + 1, r, o << 1 | 1);
}

/// Return the number of adjacent position pairs in [L, R] with different colours.
int getSum(int L, int R, int l, int r, int o) {
    if(L <= l && r <= R) {
        return sum[o];
    }
    pushdown(o);
    int mid = (l + r) >> 1;
    if(R <= mid) return getSum(L, R, l, mid, o << 1);
    if(mid < L) return getSum(L, R, mid + 1, r, o << 1 | 1);
    // The query straddles mid: also count the pair that crosses the boundary.
    return getSum(L, R, l, mid, o << 1) + getSum(L, R, mid + 1, r, o << 1 | 1) + (rc[o << 1] != lc[o << 1 | 1]);
}

/// Lowest common ancestor of x and y, found by jumping chain heads.
inline int lca(int x, int y) {
    while(top[x] != top[y]) {
        if(d[top[x]] < d[top[y]])
            y = fa[top[y]];
        else
            x = fa[top[x]];
    }
    return d[x] < d[y] ? x : y;
}

/// Number of edges on the path from x up to z whose endpoints have different colours.
/// Every caller in this file passes z as an ancestor of x (or x itself).
inline int getLen(int x, int z) {
    // col is the colour of the last vertex handled, i.e. the head of the chain just left.
    int col = ask(pos[x], 1, n, 1), res = 0;
    while(top[x] != top[z]) {
        res += (col != ask(pos[x], 1, n, 1));           // light edge entering this chain
        res += getSum(pos[top[x]], pos[x], 1, n, 1);    // edges inside this chain
        col = ask(pos[top[x]], 1, n, 1);
        x = fa[top[x]];
    }
    res += (col != ask(pos[x], 1, n, 1));
    res += getSum(pos[z], pos[x], 1, n, 1);
    return res;
}

/// Assign colour v to every vertex on the path x .. y.
inline void Change(int x, int y, int v) {
    while(top[x] != top[y]) {
        if(d[top[x]] > d[top[y]]) {
            change(pos[top[x]], pos[x], v, 1, n, 1);
            x = fa[top[x]];
        }
        else {
            change(pos[top[y]], pos[y], v, 1, n, 1);
            y = fa[top[y]];
        }
    }
    if(d[x] < d[y]) std::swap(x, y);
    change(pos[y], pos[x], v, 1, n, 1);
    return;
}

/// ------------------- tree 2 ----------------------

/// Add the virtual-tree edge x -> y (x is the parent); its length is the
/// current number of colour changes on the original path between them.
inline void ADD(int x, int y) {
    TP++;
    EDGE[TP].v = y;
    EDGE[TP].len = getLen(y, x);
    EDGE[TP].nex = E[x];
    E[x] = TP;
    return;
}

/// Order vertices by DFS index.
inline bool cmp(const int &a, const int &b) {
    return pos[a] < pos[b];
}

/// Lazily reset the per-query state of x the first time it is touched in this query.
inline void work(int x) {
    if(vis[x] == Time) return;
    vis[x] = Time;
    D[x] = E[x] = 0;
    return;
}

/// Build the virtual tree of the K key vertices in imp2[1..K]; the root ends up in RT.
inline void build_t() {
    TP = 0;
    // Sort a copy so that imp2 keeps the input order needed for the output.
    memcpy(imp + 1, imp2 + 1, K * sizeof(int));
    std::sort(imp + 1, imp + K + 1, cmp);
    stk[Top = 1] = imp[1];
    work(imp[1]);
    for(int i = 2; i <= K; i++) {
        int x = imp[i], y = lca(x, stk[Top]);
        work(x); work(y);
        // Pop every stack vertex whose parent on the stack is not above y.
        while(Top > 1 && d[y] <= d[stk[Top - 1]]) {
            ADD(stk[Top - 1], stk[Top]);
            Top--;
        }
        if(y != stk[Top]) {
            ADD(y, stk[Top]);
            stk[Top] = y;
        }
        stk[++Top] = x;
    }
    while(Top > 1) {
        ADD(stk[Top - 1], stk[Top]);
        Top--;
    }
    RT = stk[Top];
    return;
}

/// DP 1: SIZ[x] = number of key vertices in the virtual subtree of x.
void dfs_1(int x) {
    SIZ[x] = (use[x] == Time);
    for(int i = E[x]; i; i = EDGE[i].nex) {
        int y = EDGE[i].v;
        dfs_1(y);
        SIZ[x] += SIZ[y];
    }
    return;
}

/// DP 2: DEEP[x] = weighted depth of x below RT, and
/// D[x] = sum over virtual edges on the path RT .. x of len * SIZ[child].
/// For key vertices D[x] is stored in ans[x].
void dfs_2(int x) {
    if(use[x] == Time) {
        ans[x] = D[x];
    }
    for(int i = E[x]; i; i = EDGE[i].nex) {
        int y = EDGE[i].v;
        D[y] = D[x] + SIZ[y] * EDGE[i].len;
        DEEP[y] = DEEP[x] + EDGE[i].len;
        dfs_2(y);
    }
    return;
}

/// Answer the current query: build the virtual tree and run both DPs.
inline void cal() {
    build_t();
    dfs_1(RT);
    DEEP[RT] = 0;
    dfs_2(RT);
    return;
}

/// Reads the tree and q operations from stdin.
/// Operation `1 x y z` recolours the path x .. y with z; any other operation
/// `f K v1 .. vK` prints one value per key vertex, in input order.
int main() {
    memset(tag, -1, sizeof(tag));
    int q;
    scanf("%d%d", &n, &q);
    for(int i = 1; i <= n; i++) {
        scanf("%d", &val[i]);
    }
    for(int i = 1, x, y; i < n; i++) {
        scanf("%d%d", &x, &y);
        add(x, y); add(y, x);
    }
    DFS_1(1, 0);
    DFS_2(1, 1);
    build(1, n, 1);

    for(int i = 1, f, x, y, z; i <= q; i++) {
        scanf("%d%d", &f, &x);
        if(f == 1) {
            scanf("%d%d", &y, &z);
            Change(x, y, z);
        }
        else {
            Time++;
            K = x;
            for(int j = 1; j <= K; j++) {
                scanf("%d", &imp2[j]);
                use[imp2[j]] = Time;
            }
            cal();
            LL SUM = 0;
            for(int j = 1; j <= K; j++) {
                SUM += DEEP[imp2[j]];
            }
            for(int j = 1; j <= K; j++) {
                // Value = SUM + K * DEEP[v] - 2 * ans[v] + K; the trailing K adds 1 per key vertex.
                // K * DEEP is widened to LL first: as int * int it can overflow.
                printf("%lld ", SUM + static_cast<LL>(K) * DEEP[imp2[j]] - 2 * ans[imp2[j]] + K);
            }
            puts("");
        }
    }
    return 0;
}
AC代码