前置知识:换根 dp
首先考虑不换根的做法,设 \(dp1_u\) 表示在 \(u\) 的子树里走且需要回到 \(u\) 所能获得的最大收益,\(dp2_u\) 表示在 \(u\) 的子树里走且不需要回到 \(u\) 所能获得的最大收益。
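显然,\(dp1_u = V_u + \displaystyle\sum_{v \in \mathrm{son}(u)} \max(dp1_v - 2w_{u,v}, 0)\),其中 \(w_{u,v}\) 为边 \((u, v)\) 的边权(进入 \(v\) 的子树再回来要走两次这条边)。
不需要回到 \(u\) 时,最多只有一个儿子的子树是"走进去不再回来"的,因此 \(dp2_u = dp1_u + \displaystyle\max\Big(0, \max_{v \in \mathrm{son}(u)} \big[\max(dp2_v - w_{u,v}, 0) - \max(dp1_v - 2w_{u,v}, 0)\big]\Big)\)。
换根时,把父亲方向当作一个额外的"儿子"合并进来;由于需要求"去掉某个儿子之后"的最大值,所以要同时维护最大值和次大值(代码中的 max_val / second_max_val)。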
代码:
#include <stdio.h>
typedef long long ll;

/* One directed half of an undirected tree edge, stored in a linked-list
 * adjacency structure: nxt is the index of the next edge leaving the same
 * vertex (0 ends the list), end is the target vertex, dis is the weight. */
typedef struct {
    int nxt, end, dis;
} Edge;

/* Number of edges stored so far; edge[] slots are used from index 1. */
int cnt;

/* Per-vertex state, indexed by vertex id (1-based). */
int head[100007];               /* first edge leaving the vertex, 0 = none */
int max_val[100007];            /* largest gain offered via update() */
int second_max_val[100007];     /* second-largest gain offered via update() */
int v[100007];                  /* vertex value read from the input */
int dp1[100007];                /* best gain when the walk returns to the vertex */
int dp2[100007];                /* best gain when the walk may end anywhere */
int max_val_son[100007];        /* neighbour that produced max_val */
int second_max_val_son[100007]; /* neighbour that produced second_max_val */
int ans[100007];                /* final per-vertex answer */

/* Edge pool: two directed entries per undirected edge. */
Edge edge[200007];
inline void init(int n){
cnt = 0;
for (int i = 1; i <= n; i++){
head[i] = max_val[i] = second_max_val[i] = 0;
}
}
/* Read the next integer from stdin, skipping any non-digit characters.
 * A '-' immediately before the digits makes the result negative.
 * Returns 0 if end of input is reached before any digit.
 * `ch` is an int so EOF stays distinguishable from every byte; with a
 * `char` and no EOF test the skip loop never terminated at end of input.
 * `static` keeps this from being resolved to the C library's read(). */
static inline int read(void){
    int res = 0;
    int neg = 0;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')){
        neg = (ch == '-');
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9'){
        res = res * 10 + (ch - '0');
        ch = getchar();
    }
    return neg ? -res : res;
}
/* Append a directed edge start -> end with weight dis to start's adjacency
 * list; new edges are pushed at the front of the list. Slots are 1-based so
 * that head[] == 0 and nxt == 0 can mean "no more edges".
 * `static` gives the inline definition internal linkage so un-inlined calls
 * still link. */
static inline void add_edge(int start, int end, int dis){
    cnt++;
    edge[cnt].nxt = head[start];
    edge[cnt].end = end;
    edge[cnt].dis = dis;
    head[start] = cnt;
}
/* Larger of two ints. `static` gives the inline definition internal linkage;
 * a plain C99 `inline` emits no symbol, so un-inlined calls would not link. */
static inline int max(int a, int b){
    return a > b ? a : b;
}
/* Offer `val` for vertex u, produced by neighbour `v`. Keeps the largest and
 * second-largest offers seen for u, together with the neighbour behind each.
 * A new maximum demotes the old one to second place. On a tie with the
 * current maximum, the earlier neighbour stays first.
 * Starting from the zeros that init() stores, an offer <= 0 changes nothing.
 * `static` makes the inline definition linkable when not inlined. */
static inline void update(int u, int v, int val){
    if (max_val[u] < val){
        second_max_val[u] = max_val[u];
        second_max_val_son[u] = max_val_son[u];
        max_val[u] = val;
        max_val_son[u] = v;
    } else if (second_max_val[u] < val){
        second_max_val[u] = val;
        second_max_val_son[u] = v;
    }
}
/* First pass: compute subtree values bottom-up, with `father` as u's parent
 * (0 for the root).
 *   dp1[u] = v[u] + sum over children of max(dp1[child] - 2*w, 0)
 *   dp2[u] = dp1[u] + max_val[u]
 * For each child it offers, via update(), the extra gain of ending the walk
 * inside that child's subtree instead of returning from it. */
void dfs1(int u, int father){
    dp1[u] = v[u];
    for (int e = head[u]; e != 0; e = edge[e].nxt){
        const int son = edge[e].end;
        const int w = edge[e].dis;
        if (son == father){
            continue;
        }
        dfs1(son, u);
        /* Gain from entering son's subtree and coming back (edge walked twice). */
        const int back = max(dp1[son] - 2 * w, 0);
        dp1[u] += back;
        update(u, son, max(dp2[son] - w, 0) - back);
    }
    dp2[u] = dp1[u] + max_val[u];
}
/* Second pass (rerooting).
 * Precondition: dp1[u], dp2[u], max_val[u] and second_max_val[u] already
 * describe the WHOLE tree with u as the start. This holds for the root right
 * after dfs1, and this function establishes it for each child before
 * recursing.
 * For each child x it folds the "parent side" (everything outside x's
 * subtree) into x as one extra neighbour, writes the final answer ans[x],
 * and recurses.
 * ans[root] is not written here; main() sets it from dp2[root].
 * NOTE(review): recursion depth equals the tree height (up to n), which may
 * exhaust a small default stack on path-shaped trees. */
void dfs2(int u, int father){
for (int i = head[u]; i != 0; i = edge[i].nxt){
int x = edge[i].end;
if (x != father){
/* t: what x's subtree contributes to dp1[u] (walk into it and back).
 * y: dp1[u] with x's subtree removed, i.e. the best returning walk from u
 *    that stays outside x's subtree. */
int t = max(dp1[x] - edge[i].dis * 2, 0), y = dp1[u] - t, z;
/* z: dp2[u] with x's subtree removed. If x supplied u's best "end here"
 * bonus, that bonus is replaced by the runner-up; otherwise only t is
 * dropped. */
if (max_val_son[u] == x){
z = dp2[u] + second_max_val[u] - max(dp2[x] - edge[i].dis, 0);
} else {
z = dp2[u] - t;
}
/* t now: gain of stepping from x to the parent side and coming back. */
t = max(y - edge[i].dis * 2, 0);
dp1[x] += t;
/* dp2[x] still equals the old dp1[x] + max_val[x] set in dfs1, so this makes
 * ans[x] equal to the updated dp1[x]. */
ans[x] = dp2[x] + t - max_val[x];
/* Offer the parent side as a neighbour where the walk may end. */
update(x, u, max(z - edge[i].dis, 0) - t);
ans[x] += max_val[x];
/* Publish whole-tree values for x before descending into its children. */
dp2[x] = ans[x];
dfs2(x, u);
}
}
}
/* Print n in decimal to stdout, with no trailing newline.
 * Negative values are printed with a leading '-'. The magnitude is computed
 * in unsigned arithmetic so INT_MIN is handled without overflow.
 * `static` keeps this from being resolved to the C library's write(). */
static inline void write(int n){
    char buf[12];
    int len = 0;
    unsigned int m = (unsigned int)n;
    if (n < 0){
        putchar('-');
        m = 0u - m;
    }
    do {
        buf[len++] = (char)('0' + m % 10u);
        m /= 10u;
    } while (m != 0u);
    while (len > 0){
        putchar(buf[--len]);
    }
}
/* Read t test cases from stdin. Each case is n, then n vertex values, then
 * n-1 undirected weighted edges "a b c". For case i, print "Case #i:"
 * followed by ans[1..n], one per line. */
int main(void){
    int t = read();
    for (int i = 1; i <= t; i++){
        int n = read();
        init(n);
        for (int j = 1; j <= n; j++){
            v[j] = read();
        }
        for (int j = 1; j < n; j++){
            /* Named a/b rather than u/v: a local `v` would shadow the global
             * vertex-value array v[]. */
            int a = read(), b = read(), c = read();
            add_edge(a, b, c);
            add_edge(b, a, c);
        }
        printf("Case #");
        write(i);
        printf(":\n");
        if (n < 1){
            /* Empty tree: vertex 1 does not exist, so there is nothing to root. */
            continue;
        }
        dfs1(1, 0);
        dfs2(1, 0);
        ans[1] = dp2[1];
        for (int j = 1; j <= n; j++){
            write(ans[j]);
            putchar('\n');
        }
    }
    return 0;
}