题目:https://www.lydsy.com/JudgeOnline/problem.php?id=3572
关于虚树:https://www.cnblogs.com/zzqsblog/p/5560645.html
构造方法:
先把关键点按 dfs 序排序,然后依次插入树中;
插入当前点 cr 的时候,求 lca = get_lca( cr , sta[top] ) ;如果 dep[ sta[top] ] >= dep[lca] ,就一直弹栈;
弹栈结束后,看看现在的 sta[ top ] 是不是就是 lca 了,如果不是,就 sta[ ++ top ] = lca ;同时 fa[ sta[top+1] ] = lca , fa[ lca ] = sta[ top ] ;
把 cr 也加入栈中,即 sta[++top] = cr , fa[ cr ] = lca 。
sta[ 1 ] 就是虚树的根。
关于这道题:http://hzwer.com/6804.html
建好虚树,先换根 dp 得出虚树上的每个点应该被哪个点控制。换根的时候不用去掉该子树的贡献,因为不会有影响。
然后枚举虚树上的每条边 ( cr , fa ),用倍增在边上找到最浅的应该 “被控制 cr 的点控制” 的点 v ,然后 siz[v] - siz[cr] 和 siz[tv] - siz[v] 分别贡献即可,其中 tv 是 fa 在 v 方向的直接孩子。
关于不是虚树的点也不在虚树边上的那些点,自己的方法是在换根 dp 的时候处理;那个时候枚举孩子 v 的时候通过找 tv ,可以知道每个虚树上的点 cr 的不在虚树上的孩子们的 siz 和,直接贡献给控制 cr 的那个点即可。
#include<cstdio> #include<cstring> #include<algorithm> #define mkp make_pair #define fir first #define sec second using namespace std; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } const int N=3e5+5,K=18; int n,hd[N],xnt,to[N<<1],nxt[N<<1]; int dfn[N],dep[N],pre[N][K+5],bin[K+5],lg[N],siz[N],sm[N]; pair<int,int> dp[N]; int m,ans[N],tt,sta[N],tot,fa[N],h2[N],xt2,t2[N<<1],nt2[N<<1]; struct Node{int v,id;}q[N]; bool cmp(Node x,Node y){return dfn[x.v]<dfn[y.v];} void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;} void ad2(int x,int y){t2[++xt2]=y;nt2[xt2]=h2[x];h2[x]=xt2;} void ini_dfs(int cr,int fa) { dfn[cr]=++tot; dep[cr]=dep[fa]+1; pre[cr][0]=fa; siz[cr]=1; for(int t=1,u;(u=pre[pre[cr][t-1]][t-1]);t++) pre[cr][t]=u; for(int i=hd[cr],v;i;i=nxt[i]) if((v=to[i])!=fa) { ini_dfs(v,cr); siz[cr]+=siz[v]; } } int get_lca(int x,int y) { if(dep[x]<dep[y])swap(x,y); for(int t=lg[dep[x]-dep[y]];t>=0;t--) if(dep[pre[x][t]]>=dep[y]) x=pre[x][t]; if(x==y)return x; for(int t=lg[dep[x]];t>=0;t--) if(pre[x][t]!=pre[y][t]) x=pre[x][t],y=pre[y][t]; return pre[x][0]; } void build() { sort(q+1,q+m+1,cmp); tt=m; tot=0; for(int i=1;i<=m;i++) { int cr=q[i].v; if(!tot){sta[++tot]=cr;dp[cr]=mkp(0,cr);continue;} int lca=get_lca(cr,sta[tot]); while(dep[sta[tot]]>dep[lca])tot--; if(sta[tot]!=lca) { q[++tt].v=lca; fa[sta[tot+1]]=lca; fa[lca]=sta[tot]; sta[++tot]=lca; dp[lca]=mkp(N,0); } fa[cr]=lca; sta[++tot]=cr; dp[cr]=mkp(0,cr); } for(int i=1;i<=tt;i++)if(fa[q[i].v])ad2(fa[q[i].v],q[i].v); } void dfs(int cr,int fa) { for(int i=h2[cr],v;i;i=nt2[i]) if((v=t2[i])!=fa) { dfs(v,cr); int tp1=dp[v].fir+dep[v]-dep[cr],tp2=dp[cr].fir; if(tp1<tp2||(tp1==tp2&&dp[v].sec<dp[cr].sec)) dp[cr].fir=tp1,dp[cr].sec=dp[v].sec; } } int fnd2(int cr,int fa) { int d=dep[cr]-dep[fa]-1; while(d) { int lbt=(d&-d); cr=pre[cr][lg[lbt]]; d-=lbt; } return cr; } void dfsx(int cr,int fa) { int tp1=dp[fa].fir+dep[cr]-dep[fa],tp2=dp[cr].fir; if(fa&&(tp1<tp2||(tp1==tp2&&dp[fa].sec<dp[cr].sec))) dp[cr].fir=tp1,dp[cr].sec=dp[fa].sec; int s=siz[cr]; for(int i=h2[cr],v;i;i=nt2[i]) if((v=t2[i])!=fa) { int tv=fnd2(v,cr); s-=siz[tv]; dfsx(v,cr); } sm[dp[cr].sec]+=s-1;//-1 for zj } int fnd(int cr,int fa) { bool fg=(dp[cr].sec<dp[fa].sec); int x=dp[cr].fir,y=dp[fa].fir,d1=dep[cr],d2=dep[fa]; for(int t=lg[dep[cr]-dep[fa]];t>=0;t--) { int d=dep[pre[cr][t]]; int u=d1-d+x,v=d-d2+y; if(u<v||(u==v&&fg))cr=pre[cr][t]; } return cr; } bool In(int cr,int fa){return dfn[cr]>=dfn[fa]&&dfn[cr]<dfn[fa]+siz[fa];} void solve() { for(int i=1;i<=m;i++)sm[q[i].v]=0; int rt=sta[1]; dfs(rt,0); dfsx(rt,0); for(int i=1;i<=tt;i++)sm[dp[q[i].v].sec]++; for(int i=1;i<=tt;i++) { int cr=q[i].v,f=fa[cr]; if(!f) {sm[dp[cr].sec]+=n-siz[cr];continue;} int v=fnd(cr,f),tv=fnd2(cr,f); sm[dp[cr].sec]+=siz[v==f?tv:v]-siz[cr];// sm[dp[f].sec]+=siz[tv]-siz[v==f?tv:v];//tv// } for(int i=1;i<=m;i++)ans[q[i].id]=sm[q[i].v]; for(int i=1;i<=m;i++)printf("%d ",ans[i]); puts(""); } int main() { n=rdn(); for(int i=1,u,v;i<n;i++) u=rdn(),v=rdn(),add(u,v),add(v,u); bin[0]=1;for(int i=1;i<=K;i++)bin[i]=bin[i-1]<<1; for(int i=2;i<=n;i++)lg[i]=lg[i>>1]+1; ini_dfs(1,0); int Q=rdn(); while(Q--) { for(int i=1;i<=tt;i++)h2[q[i].v]=0; xt2=0; for(int i=1;i<=tt;i++)fa[q[i].v]=0;/// m=rdn(); for(int i=1;i<=m;i++)q[i].v=rdn(),q[i].id=i; build(); solve(); } return 0; }