把集合A[i]看作i点的前驱点集合,建成一个DAG,并新建超级源S,向每个前驱集合为空的点连边,那么B[i]就是S到i的必经点集合。
首先使用Lengauer-Tarjan算法建立出以S为起点的Dominator Tree,那么B[i]就是i在树上的所有祖先。
对于一个询问,构造出虚树,然后统计虚树上每一条边上的点数,累加即可。
时间复杂度$O(n+m+q\log n)$。
#include<cstdio>
#include<algorithm>
using namespace std;
const int N=200010,M=500010;
int n,m,Q,i,x,q[N],a[N],vis[N],tot,t,ans;
int g1[N],g2[N],gd[N],v[M*3+N],nxt[M*3+N],ed;
int cnt,dfn[N],id[N],fa[N],f[N],mn[N],sd[N],idom[N];
int d[N],size[N],son[N],top[N],st[N],en[N];
inline void read(int&a){char c;while(!(((c=getchar())>='0')&&(c<='9')));a=c-'0';while(((c=getchar())>='0')&&(c<='9'))(a*=10)+=c-'0';}
inline void add(int*g,int x,int y){v[++ed]=y;nxt[ed]=g[x];g[x]=ed;}
int F(int x){
if(f[x]==x)return x;
int y=F(f[x]);
if(sd[mn[x]]>sd[mn[f[x]]])mn[x]=mn[f[x]];
return f[x]=y;
}
void dfs(int x){
id[dfn[x]=++cnt]=x;
for(int i=g1[x];i;i=nxt[i])if(!dfn[v[i]])dfs(v[i]),fa[dfn[v[i]]]=dfn[x];
}
void tarjan(int S){
int i,j,k,x;
for(cnt=0,i=1;i<=n;i++)gd[i]=dfn[i]=id[i]=fa[i]=idom[i]=0,f[i]=sd[i]=mn[i]=i;
dfs(S);
for(i=n;i>1;i--){
for(j=g2[id[i]];j;j=nxt[j])F(k=dfn[v[j]]),sd[i]=sd[i]<sd[mn[k]]?sd[i]:sd[mn[k]];
add(gd,sd[i],i);
for(j=gd[f[i]=x=fa[i]];j;j=nxt[j])F(k=v[j]),idom[k]=sd[mn[k]]<x?mn[k]:x;
gd[x]=0;
}
for(i=2;i<=n;add(gd,idom[i],i),i++)if(idom[i]!=sd[i])idom[i]=idom[idom[i]];
}
void dfs1(int x){
d[x]=d[idom[x]]+1,size[x]=1;
for(int i=gd[x];i;i=nxt[i]){
dfs1(v[i]),size[x]+=size[v[i]];
if(size[v[i]]>size[son[x]])son[x]=v[i];
}
}
void dfs2(int x,int y){
top[x]=y;st[x]=++cnt;
if(son[x])dfs2(son[x],y);
for(int i=gd[x];i;i=nxt[i])if(v[i]!=son[x])dfs2(v[i],v[i]);
en[x]=cnt;
}
inline int lca(int x,int y){
for(;top[x]!=top[y];x=idom[top[x]])if(d[top[x]]<d[top[y]])swap(x,y);
return d[x]<d[y]?x:y;
}
inline int cmp(int x,int y){return st[x]<st[y];}
int main(){
for(read(n),n++,i=1;i<n;i++){
read(m);
if(!m)add(g1,n,i),add(g2,i,n);
while(m--)read(x),add(g1,x,i),add(g2,i,x);
}
tarjan(n),cnt=0,dfs1(1),dfs2(1,1);
for(read(Q);Q--;printf("%d\n",ans)){
for(read(m),tot=i=0;i<m;i++){
read(x);
if(!vis[x=dfn[x]])vis[a[++tot]=x]=1;
}
vis[a[++tot]=1]=1;
m=tot,sort(a+1,a+m+1,cmp);
for(i=1;i<m;i++)if(!vis[x=lca(a[i],a[i+1])])vis[a[++tot]=x]=1;
m=tot,sort(a+1,a+m+1,cmp);
for(ans=0,q[t=1]=a[1],i=2;i<=m;q[++t]=a[i++]){
while(st[a[i]]<st[q[t]]||en[a[i]]>en[q[t]])t--;
ans+=d[a[i]]-d[q[t]];
}
for(i=1;i<=m;i++)vis[a[i]]=0;
}
return 0;
}