P3714 - 树的难题 题解

确实是一道「难题」(orz tzc)。


先来扯一下任意顺序启发式合并 / 按秩合并的事情。设 \(n\) 为所有集合的大小之和。

普通启发式合并:指定顺序合并,\(a,b\)​​​ 合并复杂度为 \(\mathrm O(a)\)​​​ 或 \(\mathrm O(b)\)​​​,合并之后大小是 \(a+b\)​​。每次将复杂度选为 \(\min(a,b)\)​​,总复杂度 \(\mathrm O(n\log n)\)​​。证明:每个原始集合每次作为较小值合并之后,所在集合大小都会翻倍,所以每个原始集合最多贡献复杂度 \(\mathrm O(\log n)\)​​。
普通按秩合并:指定顺序合并,\(a,b\)​​ 合并复杂度为 \(\mathrm O(a)\)​​ 或 \(\mathrm O(b)\)​​,合并之后大小是 \(\max(a,b)\)​​。每次将复杂度选为 \(\min(a,b)\)​​,总复杂度 \(\mathrm O(n)\)​​。证明:每个原始集合每次作为较小值合并之后,所在集合大小大于等于该原始集合,所以以后再也不会作为所在集合大小贡献复杂度,即每个原始集合最多贡献复杂度 \(\mathrm O(1)\)​​​​​ 次。
两者对应到树上分别是 dsu on tree 和长链剖分。虽然过程并不完全一样(dsu on tree 和长链剖分都有将集合大小 +1 的成分),但只分析复杂度的话,等价于每个叶子为一个集合,大小为所在重链 / 长链大小,的启发式合并 / 按秩合并。

任意顺序按秩合并:自己决定顺序合并,\(a,b\)​​ 合并复杂度是 \(\mathrm O(a+b)\)(与 \(\mathrm O(\max(a,b))\) 是等价的),合并之后大小为 \(\max(a,b)\)。按照大小将原始集合从小到大排序,每次合并前两个,总复杂度为 \(\mathrm O(n)\)。证明就比较显然了吧。
任意顺序启发式合并:自己决定顺序合并,\(a,b\) 合并复杂度是 \(\mathrm O(a+b)\),合并之后大小为 \(a+b\)。每次选择最小的两个集合合并(可以用堆实现),总复杂度是 \(\mathrm O(n\log n)\)。证明有点东西:设当前最小集合为 \(a\),次小为 \(b\)。合并完之后过一段时间如果 \(a+b\)​​ 参与了合并,如果是最小值,那么最小值相比上次最小值显然翻了倍,合并完之后最小值显然不降;如果是次小值,合并完之后最小值显然不低于合并前的次小值 \(a+b\),那么最小值也翻了倍。也就是说每个原始集合每参与两次合并,最小值就会翻倍,那么最多贡献复杂度 \(\mathrm O(\log n)\) 次。


考虑点分治。dfs 跑出来连通块内每个点到重心路径的长度 \(len\) 和权值 \(mx\),那么对两个在重心的不同儿子树中的点 \(x,y\),设两个不同的儿子树与重心相连的边分别为 \(c_x,c_y\),那么路径 \(x\to y\) 的权值显然是 \(mx_x+mx_y-[c_x=c_y]a_{c_x}\)。现在考虑求 \(len_x+len_y\in[L,R]\) 的 \(x\to y\) 的最大权值。

由于是求最大值,不可撤销,不能一阶容斥,只能老老实实考虑两两不同的儿子树之间的贡献。考虑动态加子树、贡献子树的话,枚举当前子树的点,我们需要知道 \(c_y\neq c_x\) 且 \(dep_y\in[L-dep_x,R-dep_x]\) 的最大 \(mx_y\),以及 \(c_y=c_x\) 的。乍一看不太好维护,实际上我们可以将儿子树按 \(c\) 排序,把 \(c\) 相同的都放一起,对内部先贡献一波,然后再把总贡献放到外面去对外贡献。于是现在问题转化为根本不需要考虑 \(c\) 的版本。那么动态的话需要基于 \(dep\) 维护一个线段树,容易 2log。

然而我们不动态搞,使用任意顺序合并儿子树的思想,可以做到 1log。对两棵儿子树,搞出一个序列 \(v\),\(v_i\) 表示 \(dep=i\) 的最大 \(mx\)。那么对两棵儿子树,考虑先将它们两个之间的两两点对进行贡献,然后再合并。这样显然是能让任意两个该贡献(即一开始不在同一个原始集合)的点做出贡献的,并且这样就能离线处理了。先考虑合并的时候怎么贡献,依然是要查区间最值,但由于离线了,便可以对 \(dep\)​ 排序用单调队列。但是要排序还是带 log?注意到若使用 bfs,则不排自排。那么合并的话就对有值的 \(dep\) 取个 max。于是现在就变成了一个标准的任意顺序按秩合并,点分治内部可以线性。总复杂度就是 1log。

code(不开 O2 竟然过了!)
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define mp make_pair
#define X first
#define Y second
#define pb push_back
const int inf=0x3f3f3f3f3f3f3f3f;
const int N=200010;
int n,m,L,R;
int a[N];
vector<pair<int,int> > nei[N];
int ans=-inf;
bool vis[N];
vector<vector<int> > vec[N];
pair<int,int> q[N];int head,tail;
int dep[N],dis[N];
int &at(vector<int> &v,int x){return x<v.size()?v[x]:(v.resize(x+1,-inf),v[x]);}
void bfs(int x,vector<int> &w,int col){
	head=tail=0;
	q[tail++]=mp(x,col);dep[x]=1,dis[x]=a[col];
	while(head<tail){
		int x=q[head].X,c=q[head].Y;head++;
		at(w,dep[x])=max(at(w,dep[x]),dis[x]);
		for(int i=0;i<nei[x].size();i++){
			int y=nei[x][i].X,c0=nei[x][i].Y;if(vis[y])continue;
			if(dep[y]==-1)dep[y]=dep[x]+1,dis[y]=dis[x]+(c!=c0)*a[c0],q[tail++]=mp(y,c0);
		}
	}
	for(int i=0;i<tail;i++)dep[q[i].X]=-1;
}
int r[N];
void mrg(vector<int> &v,vector<int> &w,int add=0){
	head=tail=0;
	int now=-1;
	for(int i=(int)(w.size())-1;~i;i--){
		while(now+1<v.size()&&now+1<=R-i){
			now++;
			while(head<tail&&v[r[tail-1]]<=v[now])tail--;
			r[tail++]=now;
		}
		while(head<tail&&r[head]<L-i)head++;
//		if(head<tail)cout<<v[r[head]]<<" "<<w[i]<<" "<<add<<"!\n";
		if(head<tail)ans=max(ans,v[r[head]]+w[i]+add);
	}
	v.resize(max(v.size(),w.size()),-inf);w.resize(v.size(),-inf);for(int i=0;i<v.size();i++)v[i]=max(v[i],w[i]);
}
int sz[N],mxsz[N];
bool cmp(int x,int y){return mxsz[x]<mxsz[y];}
int gtrt(int x=1,int tot=n,int fa=0){
	sz[x]=1,mxsz[x]=0;
	int rt=0;
	for(int i=0;i<nei[x].size();i++){
		int y=nei[x][i].X;if(vis[y]||y==fa)continue;
		rt=min(rt,gtrt(y,tot,x),cmp);
		sz[x]+=sz[y],mxsz[x]=max(mxsz[x],sz[y]);
	}
	mxsz[x]=max(mxsz[x],tot-sz[x]);
	return min(rt,x,cmp);
}
void cdq(int x){
//	cout<<x<<"!!\n";
	vis[x]=true;
	vector<int> col;
	for(int i=0;i<nei[x].size();i++){
		int y=nei[x][i].X,c=nei[x][i].Y;if(vis[y])continue;
		col.pb(c);
		vector<int> w;
		bfs(y,w,c);
//		for(int j=0;j<w.size();j++)cout<<w[j]<<" ";puts("!!!!!!!");
		vec[c].pb(w);
	}
	sort(col.begin(),col.end());col.resize(unique(col.begin(),col.end())-col.begin());
	vector<int> v;v.pb(0);
	vector<pair<int,int> > ord;
	for(int i=0;i<col.size();i++){
		int mx=0;
		for(int j=0;j<vec[col[i]].size();j++)mx=max(mx,(int)vec[col[i]][j].size());
		ord.pb(mp(mx,i));
	}
	sort(ord.begin(),ord.end());
	for(int i=0;i<col.size();i++){
		int x=ord[i].Y;
		vector<pair<int,int> > ord0;
		for(int j=0;j<vec[col[x]].size();j++)ord0.pb(mp(vec[col[x]][j].size(),j));
		sort(ord0.begin(),ord0.end());
		vector<int> w;
		for(int j=0;j<vec[col[x]].size();j++)mrg(w,vec[col[x]][ord0[j].Y],-a[col[x]]);
		mrg(v,w);
	}
	for(int i=0;i<col.size();i++)vec[col[i]].clear();
	for(int i=0;i<nei[x].size();i++){
		int y=nei[x][i].X;if(vis[y])continue;
		cdq(gtrt(y,sz[y]));
	}
}
signed main(){mxsz[0]=inf;
	cin>>n>>m>>L>>R;
	for(int i=1;i<=m;i++)scanf("%lld",a+i);
	for(int i=1;i<n;i++){
		int x,y,c;
		scanf("%lld%lld%lld",&x,&y,&c);
		nei[x].pb(mp(y,c)),nei[y].pb(mp(x,c));
	}
	memset(dep,-1,sizeof(dep));
	cdq(gtrt());
	cout<<ans;
	return 0;
}
上一篇:做题记录


下一篇:luoguP3979 遥远的国度