题意简述:求对于树上每个点 \(x\) ,包含它的链的并集的大小之和,也可描述成,求对于树上每个点 \(x\) ,它能够到达的点的个数之和。
不难发现,对于点 \(x\) 而言,通过树上的路径,它能够到达的点一定构成一棵树。并且这棵树上一定含有包含 \(x\) 点的 \(s_i,t_i\) 。那么也就是说,链并大小就是包含一些关键点 \(s_i,t_i\) 的极小连通子树 \(T\) 的边数。
问题转化到这里,有一个非常经典的结论,包含一些关键点 \(a_1,a_2,a_3...a_n\) 的极小连通子树的边数为 \(|T_e|=\sum_i dep_i-\sum_{i=2}^n dep_{\text{lca}(a_i,a_{i-1})}-dep_{\text{lca}(a_1,a_2...a_n)}\),其中 \(a_1,a_2,a_3..a_n\) 按 \(\texttt{dfs}\) 序从小到大排列。
那么对于每一个点 \(x\),找出包含 \(x\) 的所有路径,并且根据上式求出极小连通子树的边数即可,但是这样做,处理 \(x\) 点被哪些路径覆盖就是 \(O(mn \log^2 n)\) 的,难以承受。
考虑如何快速统计所有对 \(x\) 有影响的路径的贡献,可以想到树上差分,对于 \(s_i,t_i\) 进行 \(+1\),对于 \(\text{lca}(s_i,t_i)\) 进行 \(-1\), 对于 \(fa_{\text{lca}(s_i,t_i)}\) 进行 \(-1\)。对于每一个点 \(x\),我们用桶来存储 \(m\) 条路径对第 \(x\) 个点的覆盖情况。这么统计点 \(x\) 被哪些路径覆盖就是 \(O(nm)\) 的。
观察一下,我们发现 \(\text{dfs}\) 时将儿子的桶与父亲的桶合并,很多位置是空的,没必要统计。并且每次都要重新暴力计算一遍最小连通子树 \(T\) 的边数,显然不是最优的。
不妨把桶换成线段树。点 \(x\) 的线段树中,区间 \([l,r]\) 表示被选中的关键点 \(\text{dfs}\) 序\(\in[l,r]\) 时,极小连通子树 \(T\) 的边数。再维护两个量 \(mx,mn\) 表示当前区间 \([l,r]\) 内被选中的关键点 \(\texttt{dfs}\) 序的最大值与最小值,\(sum\) 表示当前区间中被选中的关键点构成的极小连通子树 \(T\) 的边数,在叶子节点上存储一个 \(cnt\) 统计 \([l,l]\) 的贡献,相当于之前的桶。稍微维护一下,总时间复杂度为 \(O(mn \log n)\)。
这样的时间复杂度仍然可以优化。想一想,根据差分,父亲节点的线段树,一定与儿子节点线段树的信息是重合的,和之前的桶向上合并一样,我们将儿子节点和父亲节点的线段树合并,使用均摊时间复杂度为 \(O(n \log n)\) 的线段树合并即可。
若使用倍增/树剖 \(\text{LCA}\),总时间复杂度为 \(O(n \log^2 n+m \log n)\);使用 \(O(n\log n)-O(1)\) 的 \(LCA\) ,总时间复杂度为 \(O(n \log n+m \log n)\)。
Show the Code
#include<cstdio>
typedef long long ll;
/*------------------------Normal I/O&handmade STL--------------------------*/
inline int read() {
register int x=0,f=1;register char s=getchar();
while(s>'9'||s<'0') {if(s=='-') f=-1;s=getchar();}
while(s>='0'&&s<='9') {x=x*10+s-'0';s=getchar();}
return x*f;
}
inline void swap(int &x,int &y) {int tmp=y;y=x;x=tmp;}
/*------------------------Tree--------------------------*/
int cnt=0,num=0,tot=0;
int dep[100005],dfn[100005],rev[100005];
int h[100005],to[200005],ver[200005],f[100005][25];
inline void AddEdge(int x,int y) {to[++cnt]=y;ver[cnt]=h[x];h[x]=cnt;}
inline void prework(int x) {
int fa=f[x][0];dfn[x]=++num;rev[num]=x;
for(register int i=1;i<=20;++i) f[x][i]=f[f[x][i-1]][i-1];
for(register int i=h[x];i;i=ver[i]) {
int y=to[i];if(y==fa) continue;
dep[y]=dep[x]+1;f[y][0]=x;prework(y);
}
}
inline int LCA(int x,int y) {
if(!x||!y) return 0;
if(dep[x]>dep[y]) swap(x,y);//dep[x]<=dep[y]
for(register int i=20;i>=0;--i) if(dep[x]<=dep[f[y][i]]) y=f[y][i];
if(x==y) return x;
for(register int i=20;i>=0;--i) if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
/*------------------------SegmentTree--------------------------*/
struct Segment {int mn,mx,cnt;ll sum;}t[4000005];
int lson[4000005],rson[4000005],rt[100005];
inline void pushup(int p) {
t[p].mn=t[lson[p]].mn? t[lson[p]].mn:t[rson[p]].mn;
t[p].mx=t[rson[p]].mx? t[rson[p]].mx:t[lson[p]].mx;
t[p].sum=t[lson[p]].sum+t[rson[p]].sum-dep[LCA(rev[t[lson[p]].mx],rev[t[rson[p]].mn])];//???
}
inline void modify_add(int &p,int l,int r,int dfnId,int val) {
if(!p) p=++tot;
if(l==r) {
t[p].cnt+=val;
t[p].mx=t[p].mn=(t[p].cnt>0? dfnId:0);
t[p].sum=(t[p].cnt>0? dep[rev[dfnId]]:0);
return;
}
int mid=l+r>>1;
if(dfnId<=mid) modify_add(lson[p],l,mid,dfnId,val);
else modify_add(rson[p],mid+1,r,dfnId,val);
pushup(p);
}
inline int merge(int x,int y,int l,int r) {
if(!x||!y) return x|y;
if(l==r) {t[x].cnt+=t[y].cnt;t[x].mx=t[x].mn=(t[x].cnt>0? l:0);t[x].sum=(t[x].cnt>0? dep[rev[l]]:0);return x;}
int mid=l+r>>1;
lson[x]=merge(lson[x],lson[y],l,mid);
rson[x]=merge(rson[x],rson[y],mid+1,r);
pushup(x); return x;
}
/*------------------------Solution--------------------------*/
ll ans=0;
inline void PathAdd(int x,int y,int dfnId) {
int z=LCA(x,y),fa=f[z][0];
modify_add(rt[x],1,num,dfn[dfnId],1);
modify_add(rt[y],1,num,dfn[dfnId],1);
modify_add(rt[z],1,num,dfn[dfnId],-1);
if(fa) modify_add(rt[fa],1,num,dfn[dfnId],-1);
}
inline void solve(int x) {
int fa=f[x][0];
for(register int i=h[x];i;i=ver[i]) {int y=to[i];if(y==fa) continue;solve(y);}
ans+=t[rt[x]].sum-dep[LCA(rev[t[rt[x]].mn],rev[t[rt[x]].mx])];
if(fa) rt[fa]=merge(rt[fa],rt[x],1,num);
}
int main() {
int n=read(),m=read();
for(register int i=1;i<n;++i) {int x=read(),y=read();AddEdge(x,y);AddEdge(y,x);} dep[1]=1;prework(1);
for(register int i=1;i<=m;++i) {int s=read(),t=read();PathAdd(s,t,s);PathAdd(s,t,t);}
solve(1); printf("%lld\n",ans>>1);
return 0;
}