题目大意:给定一棵含有$n$个结点的树,每个结点有权值$p_i$。要求驻扎军队,一条边连接的两结点必须至少有一个驻扎军队。现在有$q$次询问,每次规定两个点$a,b$,分别要求它们必须驻扎/不驻扎$(0/1)$。问每次驻扎的最小费用。$n,q\leq 10^5$
------------------------
如果没有询问,那就是没有上司的舞会。设$f_{i,0/1}$表示$i$不选/选,以$i$为根的子树所花费的最小代价。有转移:
$f_{i,0}=\sum f_{j,1}$
$f_{i,1}=\sum \min(f_{j,0},f_{j,1})+p_i$
现在若有询问,我们先考虑暴力的做法,就是将它所要求的地方设成$inf/-inf$,然后每次都跑一遍树形DP。这样的复杂度是$O(nq)$的,能得到44pts。然而这样会产生很多冗余状态:发现强制要求$a,b$改变状态只会影响到$a-lca-b$这一条链。所以我们不妨考虑倍增,预处理出$f$,每次只处理$a$到$b$这条链。
设$g_{i,0/1}$表示整棵树去掉以$i$为根的子树,$i$不选/选的最小代价。有转移:
$g_{v,0}=g_{x,1}+f_{x,1}-\min(f_{v,0},f_{v,1})$
$g_{v,1}=\min(g_{x,0}+f_{x,0}-f_{v,1},g_{v,0})$
令$anc$表示$i$的$2^j$祖先,设$ff_{i,j,0/1,0/1}$表示$anc-i$上路径,$i$不选/选,$anc$不选/选的最小代价。通过枚举$2^{j-1}$祖先的状态进行转移。注意边界处理。
这样我们求出了$f,g,ff$,可以着手对询问的处理了。发现若$a$是$b$的祖先,那么直接倍增上去即可,最后加上$g_{a,x}$;若不为祖先-子孙关系,那么就都先倍增到$lca$的儿子处,然后枚举$lca$和儿子的两个状态取最小值即可。
时间复杂度$O((n+q)\log n)$。注意开$long\ long$
代码:
#include<iostream> #include<cstdio> #include<cstring> #define int long long using namespace std; const int N=100005; const int inf=1e18; char id[10]; int f[N][2],g[N][2],fa[N][21],ff[N][21][2][2],v[N],dep[N],n,m; int head[N],cnt; struct node { int next,to; }edge[N*2]; inline int read() { int x=0,f=1;char ch=getchar(); while(!isdigit(ch)){if (ch=='-')f=-1;ch=getchar();} while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();} return x*f; } inline void add(int from,int to) { edge[++cnt]=(node){head[from],to}; head[from]=cnt; } inline void dfs1(int now,int father) { fa[now][0]=father;dep[now]=dep[father]+1; f[now][1]=v[now]; for (int i=head[now];i;i=edge[i].next) { int to=edge[i].to; if (to==father) continue; dfs1(to,now); f[now][0]+=f[to][1]; f[now][1]+=min(f[to][0],f[to][1]); } } inline void dfs2(int now,int fa) { for (int i=head[now];i;i=edge[i].next) { int to=edge[i].to; if (to==fa) continue; g[to][0]=g[now][1]+f[now][1]-min(f[to][0],f[to][1]); g[to][1]=min(g[to][0],g[now][0]+f[now][0]-f[to][1]); dfs2(to,now); } } inline int solve(int x,int a,int y,int b) { if (dep[x]<dep[y]) swap(x,y),swap(a,b); int tx[2]={inf,inf},ty[2]={inf,inf}; int nx[2],ny[2]; tx[a]=f[x][a];ty[b]=f[y][b]; for (int i=19;i>=0;i--) { if (dep[fa[x][i]]>=dep[y]) { nx[0]=nx[1]=inf; for (int j=0;j<2;j++) for (int k=0;k<2;k++) nx[j]=min(nx[j],tx[k]+ff[x][i][k][j]); tx[0]=nx[0],tx[1]=nx[1],x=fa[x][i]; } } if (x==y) return tx[b]+g[x][b]; for (int i=19;i>=0;i--) { if (fa[x][i]!=fa[y][i]) { nx[0]=nx[1]=ny[0]=ny[1]=inf; for (int j=0;j<2;j++) for (int k=0;k<2;k++) nx[j]=min(nx[j],tx[k]+ff[x][i][k][j]), ny[j]=min(ny[j],ty[k]+ff[y][i][k][j]); tx[0]=nx[0],tx[1]=nx[1],x=fa[x][i]; ty[0]=ny[0],ty[1]=ny[1],y=fa[y][i]; } } int l=fa[x][0]; int ans0=f[l][0]-f[x][1]-f[y][1]+tx[1]+ty[1]+g[l][0]; int ans1=f[l][1]-min(f[x][0],f[x][1])-min(f[y][0],f[y][1])+min(tx[0],tx[1])+min(ty[0],ty[1])+g[l][1]; return min(ans0,ans1); } signed main() { n=read();m=read();scanf("%s",id); for (int i=1;i<=n;i++) v[i]=read(); for (int i=1;i<n;i++) { int x=read(),y=read(); add(x,y);add(y,x); } dfs1(1,0); dfs2(1,0); for (int i=1;i<=n;i++) { ff[i][0][0][0]=inf; ff[i][0][0][1]=f[fa[i][0]][1]-min(f[i][0],f[i][1]); ff[i][0][1][0]=f[fa[i][0]][0]-f[i][1]; ff[i][0][1][1]=f[fa[i][0]][1]-min(f[i][0],f[i][1]); } for (int j=1;j<=19;j++) for (int i=1;i<=n;i++) { int tmp=fa[i][j-1]; fa[i][j]=fa[tmp][j-1]; for (int u=0;u<2;u++) for (int v=0;v<2;v++) { ff[i][j][u][v]=inf; for (int w=0;w<2;w++) ff[i][j][u][v]=min(ff[i][j][u][v],ff[i][j-1][u][w]+ff[tmp][j-1][w][v]); } } while(m--) { int a=read(),x=read(),b=read(),y=read(); if (!x&&!y&&(fa[b][0]==a||fa[a][0]==b)){ printf("-1\n"); continue; } printf("%lld\n",solve(a,x,b,y)); } return 0; }