暴力树剖做法显然,即使做到两个log也不那么优美。
考虑避免树剖做到一个log。那么容易想到树上差分,也即要对每个点统计所有经过他的路径产生的总贡献(显然就是所有这些路径端点所构成的斯坦纳树大小),并支持在一个log内插入删除合并。
考虑怎么求树上一些点所构成的斯坦纳树大小。由虚树的构造过程容易联想到,这就是按dfs序排序后这些点的深度之和-相邻点的lca的深度之和(首尾视作相邻),也就相当于按dfs序遍历所有要经过的点并回到原点的路径长度/2。
这个东西显然(应该)可以set启发式合并维护,但同样就变成了两个log。可以改为线段树合并,线段树上每个节点维护该dfs序区间内dfs序最小和最大的被选中节点,合并时减去跨过两区间的一对相邻点的lca的深度即可。这需要计算O(nlogn)次lca,使用欧拉序rmq做到O(1)lca查询就能以总复杂度O(nlogn)完成。
#include<bits/stdc++.h> using namespace std; #define ll long long #define N 100010 char getc(){char c=getchar();while ((c<'A'||c>'Z')&&(c<'a'||c>'z')&&(c<'0'||c>'9')) c=getchar();return c;} int gcd(int n,int m){return m==0?n:gcd(m,n%m);} int read() { int x=0,f=1;char c=getchar(); while (c<'0'||c>'9') {if (c=='-') f=-1;c=getchar();} while (c>='0'&&c<='9') x=(x<<1)+(x<<3)+(c^48),c=getchar(); return x*f; } int n,m,p[N],dfn[N],id[N],fa[N],deep[N],cnt,t; struct data{int to,nxt; }edge[N<<1]; vector<int> ins[N],del[N]; void addedge(int x,int y){t++;edge[t].to=y,edge[t].nxt=p[x],p[x]=t;} void dfs(int k) { dfn[k]=++cnt;id[cnt]=k; for (int i=p[k];i;i=edge[i].nxt) if (edge[i].to!=fa[k]) { fa[edge[i].to]=k; deep[edge[i].to]=deep[k]+1; dfs(edge[i].to); } } namespace euler_tour { int dfn[N],id[N<<1],LG2[N<<1],f[N<<1][19],cnt; void dfs(int k) { dfn[k]=++cnt;id[cnt]=k; for (int i=p[k];i;i=edge[i].nxt) if (edge[i].to!=fa[k]) { dfs(edge[i].to); id[++cnt]=k; } } void build() { dfs(1); for (int i=1;i<=cnt;i++) f[i][0]=id[i]; for (int j=1;j<19;j++) for (int i=1;i<=cnt;i++) if (deep[f[i][j-1]]<deep[f[min(cnt,i+(1<<j-1))][j-1]]) f[i][j]=f[i][j-1]; else f[i][j]=f[min(cnt,i+(1<<j-1))][j-1]; for (int i=2;i<=cnt;i++) { LG2[i]=LG2[i-1]; if ((2<<LG2[i])<=i) LG2[i]++; } } int lca(int x,int y) { if (!x||!y) return 0; x=dfn[x],y=dfn[y]; if (x>y) swap(x,y); if (deep[f[x][LG2[y-x+1]]]<deep[f[y-(1<<LG2[y-x+1])+1][LG2[y-x+1]]]) return f[x][LG2[y-x+1]]; else return f[y-(1<<LG2[y-x+1])+1][LG2[y-x+1]]; } } using euler_tour::lca; ll ans; int root[N]; struct data2{int l,r,cnt,lnode,rnode,ans; }tree[N<<6]; void up(int k) { tree[k].lnode=tree[tree[k].l].lnode;if (!tree[k].lnode) tree[k].lnode=tree[tree[k].r].lnode; tree[k].rnode=tree[tree[k].r].rnode;if (!tree[k].rnode) tree[k].rnode=tree[tree[k].l].rnode; tree[k].ans=tree[tree[k].l].ans+tree[tree[k].r].ans-deep[lca(tree[tree[k].l].rnode,tree[tree[k].r].lnode)]; } int merge(int x,int y,int l,int r) { if (!x||!y) return x|y; if (l==r) { tree[x].cnt+=tree[y].cnt; if (tree[x].cnt==0) tree[x].lnode=tree[x].rnode=tree[x].ans=0; else tree[x].lnode=tree[x].rnode=id[l],tree[x].ans=deep[id[l]]; return x; } int mid=l+r>>1; tree[x].l=merge(tree[x].l,tree[y].l,l,mid); tree[x].r=merge(tree[x].r,tree[y].r,mid+1,r); up(x); return x; } void modify(int &k,int l,int r,int x,int op) { if (!k) k=++cnt; if (l==r) { tree[k].cnt+=op; if (tree[k].cnt==0) tree[k].lnode=tree[k].rnode=tree[k].ans=0; else tree[k].lnode=tree[k].rnode=id[l],tree[k].ans=deep[id[l]]; return; } int mid=l+r>>1; if (x<=mid) modify(tree[k].l,l,mid,x,op); else modify(tree[k].r,mid+1,r,x,op); up(k); } void solve(int k) { for (int i=p[k];i;i=edge[i].nxt) if (edge[i].to!=fa[k]) { solve(edge[i].to); root[k]=merge(root[k],root[edge[i].to],1,n); } for (int i:ins[k]) modify(root[k],1,n,dfn[i],1); for (int i:del[k]) modify(root[k],1,n,dfn[i],-1); ans+=tree[root[k]].ans-deep[lca(tree[root[k]].lnode,tree[root[k]].rnode)]; } int main() { #ifndef ONLINE_JUDGE freopen("a.in","r",stdin); freopen("a.out","w",stdout); const char LL[]="%I64d\n"; #else const char LL[]="%lld\n"; #endif n=read(),m=read(); for (int i=1;i<n;i++) { int x=read(),y=read(); addedge(x,y),addedge(y,x); } dfs(1); euler_tour::build(); for (int i=1;i<=m;i++) { int x=read(),y=read(),z=fa[lca(x,y)]; ins[x].push_back(x);ins[x].push_back(y); ins[y].push_back(x);ins[y].push_back(y); del[z].push_back(x);del[z].push_back(y); del[z].push_back(x);del[z].push_back(y); } cnt=0; solve(1); cout<<ans/2; return 0; }