题解「ZJOI2019 语言」

题意简述:求对于树上每个点 \(x\) ,包含它的链的并集的大小之和,也可描述成,求对于树上每个点 \(x\) ,它能够到达的点的个数之和。

不难发现,对于点 \(x\) 而言,通过树上的路径,它能够到达的点一定构成一棵树。并且这棵树上一定含有包含 \(x\) 点的 \(s_i,t_i\) 。那么也就是说,链并大小就是包含一些关键点 \(s_i,t_i\) 的极小连通子树 \(T\) 的边数。

问题转化到这里,有一个非常经典的结论,包含一些关键点 \(a_1,a_2,a_3...a_n\) 的极小连通子树的边数为 \(|T_e|=\sum_i dep_i-\sum_{i=2}^n dep_{\text{lca}(a_i,a_{i-1})}-dep_{\text{lca}(a_1,a_2...a_n)}\),其中 \(a_1,a_2,a_3..a_n\) 按 \(\texttt{dfs}\) 序从小到大排列。

那么对于每一个点 \(x\),找出包含 \(x\) 的所有路径,并且根据上式求出极小连通子树的边数即可,但是这样做,处理 \(x\) 点被哪些路径覆盖就是 \(O(mn \log^2 n)\) 的,难以承受。

考虑如何快速统计所有对 \(x\) 有影响的路径的贡献,可以想到树上差分,对于 \(s_i,t_i\) 进行 \(+1\),对于 \(\text{lca}(s_i,t_i)\) 进行 \(-1\), 对于 \(fa_{\text{lca}(s_i,t_i)}\) 进行 \(-1\)。对于每一个点 \(x\),我们用桶来存储 \(m\) 条路径对第 \(x\) 个点的覆盖情况。这么统计点 \(x\) 被哪些路径覆盖就是 \(O(nm)\) 的。

观察一下,我们发现 \(\text{dfs}\) 时将儿子的桶与父亲的桶合并,很多位置是空的,没必要统计。并且每次都要重新暴力计算一遍最小连通子树 \(T\) 的边数,显然不是最优的。

不妨把桶换成线段树。点 \(x\) 的线段树中,区间 \([l,r]\) 表示被选中的关键点 \(\text{dfs}\) 序\(\in[l,r]\) 时,极小连通子树 \(T\) 的边数。再维护两个量 \(mx,mn\) 表示当前区间 \([l,r]\) 内被选中的关键点 \(\texttt{dfs}\) 序的最大值与最小值,\(sum\) 表示当前区间中被选中的关键点构成的极小连通子树 \(T\) 的边数,在叶子节点上存储一个 \(cnt\) 统计 \([l,l]\) 的贡献,相当于之前的桶。稍微维护一下,总时间复杂度为 \(O(mn \log n)\)。

这样的时间复杂度仍然可以优化。想一想,根据差分,父亲节点的线段树,一定与儿子节点线段树的信息是重合的,和之前的桶向上合并一样,我们将儿子节点和父亲节点的线段树合并,使用均摊时间复杂度为 \(O(n \log n)\) 的线段树合并即可。

若使用倍增/树剖 \(\text{LCA}\),总时间复杂度为 \(O(n \log^2 n+m \log n)\);使用 \(O(n\log n)-O(1)\) 的 \(LCA\) ,总时间复杂度为 \(O(n \log n+m \log n)\)。

Show the Code

#include<cstdio>
typedef long long ll;
/*------------------------Normal I/O&handmade STL--------------------------*/ 
inline int read() {
	register int x=0,f=1;register char s=getchar();
	while(s>'9'||s<'0') {if(s=='-') f=-1;s=getchar();}
	while(s>='0'&&s<='9') {x=x*10+s-'0';s=getchar();}
	return x*f; 
} 
inline void swap(int &x,int &y) {int tmp=y;y=x;x=tmp;} 
/*------------------------Tree--------------------------*/ 
int cnt=0,num=0,tot=0;
int dep[100005],dfn[100005],rev[100005];
int h[100005],to[200005],ver[200005],f[100005][25];
inline void AddEdge(int x,int y) {to[++cnt]=y;ver[cnt]=h[x];h[x]=cnt;}
inline void prework(int x) {
	int fa=f[x][0];dfn[x]=++num;rev[num]=x;
	for(register int i=1;i<=20;++i) f[x][i]=f[f[x][i-1]][i-1];
	for(register int i=h[x];i;i=ver[i]) {
		int y=to[i];if(y==fa) continue;
		dep[y]=dep[x]+1;f[y][0]=x;prework(y); 
	}
}
inline int LCA(int x,int y) {
	if(!x||!y) return 0; 
	if(dep[x]>dep[y]) swap(x,y);//dep[x]<=dep[y]
	for(register int i=20;i>=0;--i) if(dep[x]<=dep[f[y][i]]) y=f[y][i];
	if(x==y) return x;
	for(register int i=20;i>=0;--i) if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
	return f[x][0];
}
/*------------------------SegmentTree--------------------------*/
struct Segment {int mn,mx,cnt;ll sum;}t[4000005];
int lson[4000005],rson[4000005],rt[100005];
inline void pushup(int p) {
	t[p].mn=t[lson[p]].mn? t[lson[p]].mn:t[rson[p]].mn;
	t[p].mx=t[rson[p]].mx? t[rson[p]].mx:t[lson[p]].mx;
	t[p].sum=t[lson[p]].sum+t[rson[p]].sum-dep[LCA(rev[t[lson[p]].mx],rev[t[rson[p]].mn])];//???
}
inline void modify_add(int &p,int l,int r,int dfnId,int val) {
	if(!p) p=++tot;
	if(l==r) {
		t[p].cnt+=val;
		t[p].mx=t[p].mn=(t[p].cnt>0? dfnId:0);
		t[p].sum=(t[p].cnt>0? dep[rev[dfnId]]:0);
		return;
	}
	int mid=l+r>>1;
	if(dfnId<=mid) modify_add(lson[p],l,mid,dfnId,val);
	else modify_add(rson[p],mid+1,r,dfnId,val);
	pushup(p);
}
inline int merge(int x,int y,int l,int r) {
	if(!x||!y) return x|y;
	if(l==r) {t[x].cnt+=t[y].cnt;t[x].mx=t[x].mn=(t[x].cnt>0? l:0);t[x].sum=(t[x].cnt>0? dep[rev[l]]:0);return x;}
	int mid=l+r>>1;
	lson[x]=merge(lson[x],lson[y],l,mid);
	rson[x]=merge(rson[x],rson[y],mid+1,r);
	pushup(x); return x;
}
/*------------------------Solution--------------------------*/
ll ans=0;
inline void PathAdd(int x,int y,int dfnId) {
	int z=LCA(x,y),fa=f[z][0];
	modify_add(rt[x],1,num,dfn[dfnId],1);
	modify_add(rt[y],1,num,dfn[dfnId],1);
	modify_add(rt[z],1,num,dfn[dfnId],-1);
	if(fa) modify_add(rt[fa],1,num,dfn[dfnId],-1);
}
inline void solve(int x) {
	int fa=f[x][0];
	for(register int i=h[x];i;i=ver[i]) {int y=to[i];if(y==fa) continue;solve(y);}
	ans+=t[rt[x]].sum-dep[LCA(rev[t[rt[x]].mn],rev[t[rt[x]].mx])]; 
	if(fa) rt[fa]=merge(rt[fa],rt[x],1,num);
}
int main() {
	int n=read(),m=read();
	for(register int i=1;i<n;++i) {int x=read(),y=read();AddEdge(x,y);AddEdge(y,x);} dep[1]=1;prework(1);
	for(register int i=1;i<=m;++i) {int s=read(),t=read();PathAdd(s,t,s);PathAdd(s,t,t);}
	solve(1); printf("%lld\n",ans>>1);
	return 0;
}
上一篇:[P4145] 花神游历各国 - 线段树


下一篇:kd-tree