定义$S_{i}$表示第$i$条链所包含的点的集合,$(x,y)$合法当且仅当$x\ne y$且$\exists i,\{x,y\}\subseteq S_{i}$(答案即$\frac{合法点对数}{2}$),显然后者等价于$y\in \cup_{x\in S_{i}}S_{i}$,因此合法点对数为$\sum_{x=1}^{n}|\cup_{x\in S_{i}}S_{i}|-1$
结论:$链并的大小=链端点所构成的虚树点数=\frac{按照dfs序排序后相邻(包括首尾)两点距离和}{2}+1$
前者显然,后者证明如下:
对每一条边统计经过次数,设其连结的深度较大的点为$x$,那么记$p_{i}=1$当且仅当$i$在$x$子树内(否则$p_{i}=0$),观察可得两个点$x$和$y$经过这条边当且仅当$p_{x}+p_{y}=1$
考虑dfs序的性质:每一个子树一定是一段区间,因此设端点按dfs序排序后为$a_{1},a_{2},...,a_{k}$,$S=\{i|p_{a_{i}}=1\}$一定是一段区间$[l,r]$,观察可得当$[l,r]=\emptyset$或$[l,r]=[1,k]$时该边答案为0,否则答案为2
考虑$[l,r]=\emptyset$或$[l,r]=[1,k]$的条件,即等价于这条边不在虚树上,那么$\frac{按照dfs序排序后相邻(包括首尾)两点距离和}{2}$即为边数,根据树的性质,加1即为点数
根据这个结论,将每条链差分并用线段树合并来找到所有端点,线段树上维护:1.个数(判断是否存在);2.区间最小点;3.区间最大点;4.区间相邻点距离和(最左和最右可以在外面算)即可,如果用st表维护lca可以做到$o(n\log_{2}n)$
1 #include<bits/stdc++.h> 2 using namespace std; 3 #define N 100005 4 #define mid (l+r>>1) 5 struct ji{ 6 int nex,to; 7 }edge[N<<1]; 8 int V,E,n,m,x,y,head[N],dfn[N],id[N],s[N],f[N][21],r[N],ls[N*100],rs[N*100],vis[N*100],mn[N*100],mx[N*100],sum[N*100]; 9 long long ans; 10 void add(int x,int y){ 11 edge[E].nex=head[x]; 12 edge[E].to=y; 13 head[x]=E++; 14 } 15 void dfs(int k,int fa,int sh){ 16 dfn[k]=++x; 17 id[x]=k; 18 s[k]=sh; 19 f[k][0]=fa; 20 for(int i=1;i<=20;i++)f[k][i]=f[f[k][i-1]][i-1]; 21 for(int i=head[k];i!=-1;i=edge[i].nex) 22 if (edge[i].to!=fa)dfs(edge[i].to,k,sh+1); 23 } 24 int lca(int x,int y){ 25 if (s[x]<s[y])swap(x,y); 26 for(int i=20;i>=0;i--) 27 if (s[f[x][i]]>=s[y])x=f[x][i]; 28 if (x==y)return x; 29 for(int i=20;i>=0;i--) 30 if (f[x][i]!=f[y][i]){ 31 x=f[x][i]; 32 y=f[y][i]; 33 } 34 return f[x][0]; 35 } 36 int dis(int x,int y){ 37 return s[x]+s[y]-2*s[lca(x,y)]; 38 } 39 void up(int k){ 40 mn[k]=min(mn[ls[k]],mn[rs[k]]); 41 mx[k]=max(mx[ls[k]],mx[rs[k]]); 42 sum[k]=sum[ls[k]]+sum[rs[k]]; 43 if ((mx[ls[k]])&&(mn[rs[k]]<=n))sum[k]+=dis(id[mx[ls[k]]],id[mn[rs[k]]]); 44 } 45 void update(int &k,int l,int r,int x,int y){ 46 if (!k){ 47 k=++V; 48 mn[k]=n+1; 49 } 50 if (l==r){ 51 vis[k]+=y; 52 if (vis[k]>0)mn[k]=mx[k]=l; 53 else{ 54 mn[k]=n+1; 55 mx[k]=0; 56 } 57 return; 58 } 59 if (x<=mid)update(ls[k],l,mid,x,y); 60 else update(rs[k],mid+1,r,x,y); 61 up(k); 62 } 63 int merge(int k1,int k2){ 64 if ((!k1)||(!k2))return k1+k2; 65 if ((!ls[k1])&&(!rs[k1])){ 66 vis[k1]+=vis[k2]; 67 if (vis[k1]>0){ 68 mn[k1]=min(mn[k1],mn[k2]); 69 mx[k1]=max(mx[k1],mx[k2]); 70 } 71 else{ 72 mn[k1]=n+1; 73 mx[k1]=0; 74 } 75 return k1; 76 } 77 ls[k1]=merge(ls[k1],ls[k2]); 78 rs[k1]=merge(rs[k1],rs[k2]); 79 up(k1); 80 return k1; 81 } 82 void dfs(int k,int fa){ 83 for(int i=head[k];i!=-1;i=edge[i].nex) 84 if (edge[i].to!=fa){ 85 dfs(edge[i].to,k); 86 r[k]=merge(r[k],r[edge[i].to]); 87 } 88 if (mn[r[k]]!=mx[r[k]])ans+=sum[r[k]]+dis(id[mn[r[k]]],id[mx[r[k]]]); 89 } 90 int main(){ 91 scanf("%d%d",&n,&m); 92 memset(head,-1,sizeof(head)); 93 for(int i=1;i<n;i++){ 94 scanf("%d%d",&x,&y); 95 add(x,y); 96 add(y,x); 97 } 98 x=0; 99 dfs(1,0,1); 100 mn[0]=n+1; 101 for(int i=1;i<=m;i++){ 102 scanf("%d%d",&x,&y); 103 int z=lca(x,y); 104 update(r[x],1,n,dfn[x],1); 105 update(r[x],1,n,dfn[y],1); 106 update(r[y],1,n,dfn[x],1); 107 update(r[y],1,n,dfn[y],1); 108 update(r[f[z][0]],1,n,dfn[x],-2); 109 update(r[f[z][0]],1,n,dfn[y],-2); 110 } 111 dfs(1,0); 112 printf("%lld",ans/4); 113 }View Code