很明显的要用树上差分或者树链剖分,但是要求 将某一条边的权值归0后的最大路径最短是多少 ,就发人深思了。
对于维护树上一条路径的大小,可以先用前缀和维护出从根到两个节点的路径大小和,再减去两点LCA的前缀和就可以了。我们可以先用倍增/树剖甚至是ST表求出LCA并预处理出每一条路径长,然后思考,我们怎么得知答案。
容易想到个鬼二分答案,对于check答案是否正确的操作,我们可以记录下路径的最大值,并将所有大于当前答案的路径给更新到树上(用树上差分或者前缀和),再统计出 所有被更新了总更新路径条数 的边的最大值。用路径最大值减去后面这个最大值,若差小于等于答案,则可行,否则不可行。
显然的证明:只要有一条大于当前答案的路径无法被某一条边给缩到小于答案以下,则答案不可行。
代码:
#include <iostream> #include <cstdio> #include <cstring> #include <algorithm> #include <cmath> #define maxn 300005 int n,m,nex[2 * maxn],head[maxn],to[2 * maxn],edge[2 * maxn],tot,fa[maxn][35],sum[maxn],depth[maxn],tag[maxn],mt,cnt = 0,maxm,maxc,length[maxn],lc[maxn]; struct road{ int x,y; }rd[maxn]; void add(int x,int y,int z) { to[++tot] = y; edge[tot] = z; nex[tot] = head[x]; head[x] = tot; } int lca(int x,int y) { if (depth[x] < depth[y]) std::swap(x,y); for (int i = mt;i >= 0;i--) if (depth[fa[x][i]] >= depth[y]) x = fa[x][i]; if (x == y) return x; for (int i = mt;i >= 0;i--) if (fa[x][i] != fa[y][i]) x = fa[x][i],y = fa[y][i]; return fa[x][0]; } void dfs2(int x) { for (int i = head[x];i;i = nex[i]) { int y = to[i]; if (y == fa[x][0]) continue; dfs2(y); tag[x] += tag[y]; }//树上差分统计子树信息 } int check(int x) { maxc = cnt = 0; for (int i = 1;i <= n;i++) tag[i] = 0; for (int i = 1;i <= m;i++) { if (length[i] > x) { tag[rd[i].x]++,tag[rd[i].y]++,tag[lc[i]] -= 2,cnt++; }//树上差分 } dfs2(1); for (int i = 1;i <= n;i++) { if (tag[i] >= cnt) maxc = std::max(maxc,sum[i] - sum[fa[i][0]]); } if (maxm - maxc <= x) return 1; else return 0; } int divide() { int ans; int l = 0,r = maxm; while (l < r) { int mid = (l + r) / 2; if (check(mid)) r = mid; else l = mid + 1; } return l; } void dfs(int x,int f) { depth[x] = depth[f] + 1; fa[x][0] = f; for (int i = head[x];i;i = nex[i]) { int y = to[i]; if (y == f) continue; sum[y] = sum[x] + edge[i]; dfs(y,x); } } void deal() { for (int i = 1;(1 << (i - 1)) <= n;i++) for (int j = 1;j <= n;j++) { fa[j][i] = fa[fa[j][i - 1]][i - 1]; mt = i; } //倍增预处理 } int main() { scanf("%d%d",&n,&m); for (int i = 1;i < n;i++) { int x,y,z; scanf("%d%d%d",&x,&y,&z); add(x,y,z); add(y,x,z); } for (int i = 1;i <= m;i++) { scanf("%d%d",&rd[i].x,&rd[i].y); } dfs(1,0); deal(); for (int i = 1;i <= m;i++) { int p = lca(rd[i].x,rd[i].y); lc[i] = p; length[i] = sum[rd[i].x] + sum[rd[i].y] - 2 * sum[p]; maxm = std::max(maxm,length[i]);//统计出每条路径的LCA、长度、以及最长路径 } printf("%d\n",divide()); return 0; }