【DS】P5327 [ZJOI2019]语言

我竟然能独立想出来()

首先树上统计点对问题考虑 dfs 一遍顺便统计,再加上数据结构之类的。

考虑对于第 \(i\) 种语言,\(x,y\) 能开展贸易说明都被 \(i\) 覆盖到了。考虑每种语言覆一个颜色?

不需要。我们发现对于 \(x\) 我们只关心执行了所有跟 \(x\) 有关的覆盖操作(即覆盖到 \(x\))后被覆盖的点的并集的大小。对此直接树上差分下即可。

考虑类虚树,我们现在求出一个集合 \(S=\{x|x\text{是每次覆盖操作的端点}\}\),按 dfs 序排序,那么他们相互间的路径的并集的大小就是 \(\sum_{x\in S}dep[x]-\sum_{x,y\in S,\text{x,y相邻}}dep[LCA(x,y)]\)。(注意 1,n 也算相邻)考虑对当前遍历到的点的贡献就为此。(点数等于边数 +1,但要减去本身)

这个式子画个图就能发现,需要记忆下。

那么如何维护这个式子?我们只需要 multiset+启发式合并 即可。

至于这里为什么是 multiset,因为用 set 的话假如重复了在此删去,回溯后就无了 (好吧,我本来也用 set 的,直到过不去样例)。至于重复点的话,显然对于答案不会有影响,因为加了重复点,又减去它俩的 LCA。

先考虑暴力合并子树答案。加入一个点,只需要考虑前驱后驱的影响。删除也同样。

至于启发式合并的话,swap 之后,要把 y 的答案赋给 x 即可,但这里有个细节,赋的答案不能减去 1,n 相邻的。

#include <bits/stdc++.h>
#define int long long
#define pb push_back
using namespace std;
#define N (int)(1e5+5)
vector<int>g[N];
int n,m;

namespace SP {
	int fa[N],sz[N],top[N],son[N],dep[N],id[N],tot;
	void dfs1(int x,int ff) {
		fa[x]=ff; dep[x]=dep[ff]+1; sz[x]=1;
		for(int y:g[x]) {
			if(y==ff) continue;
			dfs1(y,x); sz[x]+=sz[y];
			if(sz[y]>sz[son[x]]) son[x]=y;
		}
	}
	void dfs2(int x,int tp) {
		top[x]=tp; id[x]=++tot;
		if(son[x]) dfs2(son[x],tp);
		for(int y:g[x]) {
			if(y==fa[x]||y==son[x]) continue;
			dfs2(y,y);
		} 
	}
	void init() {
		dfs1(1,0); dfs2(1,1);
	}
	int LCA(int x,int y) {
		while(top[x]!=top[y]) {
			if(dep[top[x]]<dep[top[y]]) swap(x,y);
			x=fa[top[x]];
		}
		return dep[x]<dep[y]?x:y;
	}
	int dist(int x,int y) {
		return dep[x]+dep[y]-2*dep[LCA(x,y)];
	}
}
using namespace SP;
vector<pair<int,int> >vec[N];

void modify(int x,int y) {
	int d=LCA(x,y);
	vec[x].pb(make_pair(x,1));
	vec[x].pb(make_pair(y,1));
	vec[y].pb(make_pair(x,1));
	vec[y].pb(make_pair(y,1));
	if(fa[d]) vec[fa[d]].pb(make_pair(y,-1)),vec[fa[d]].pb(make_pair(x,-1));
}
struct cmp {
	bool operator () (int x,int y) {
		return id[x]<id[y];
	}
}; 
multiset<int,cmp>s[N]; //这里用 multiset 的原因事实上是为了防止此处删掉 i 但是回溯过去还有 i, 多个 i 排序后在一起对答案没有影响 
int ans[N],dans[N];

void ins(int x,int i) {
//	if(x==4) cout<<"ins "<<i<<endl;
	auto nex=s[x].lower_bound(i),las=nex; int qwq=0; 
	if(las!=s[x].begin()) {
		--las; 
	//	if(x==4) cout<<*las<<" las i "<<i<<" "<<dep[LCA(*las,i)]<<endl;
		++qwq,ans[x]-=dep[LCA(*las,i)];
	}
	if(nex!=s[x].end()) {
	//	if(x==4) cout<<*nex<<" nex i "<<i<<" "<<dep[LCA(*nex,i)]<<endl;
		++qwq; ans[x]-=dep[LCA(*nex,i)];
	}
	if(qwq==2) {
	//	if(x==4) cout<<*las<<" las nex "<<*nex<<" "<<dep[LCA(*las,*nex)]<<endl;
		ans[x]+=dep[LCA(*las,*nex)];
	}
	s[x].insert(i); ans[x]+=dep[i];
}

void del(int x,int i) {
	auto it=s[x].find(i),nex=it,las=it; int qwq=0;
	if(nex!=s[x].end()) {
		++nex; if(nex!=s[x].end()) ++qwq,ans[x]+=dep[LCA(i,*nex)];
	}
	if(las!=s[x].begin()) {
		--las; if(las!=s[x].begin()) ++qwq,ans[x]+=dep[LCA(i,*las)];		
	}
	if(qwq==2) ans[x]-=dep[LCA(*las,*nex)];
	s[x].erase(s[x].find(i)); ans[x]-=dep[i];
}

void dfs(int x,int ff) {
	for(int y:g[x]) {
		if(y==ff) continue;
		dfs(y,x);
		if(s[y].size()>s[x].size()) s[x].swap(s[y]),ans[x]=ans[y];
		//ans[x]+=ans[y];
		for(int i:s[y]) {
			ins(x,i);
		}
	}
//	s[x].insert(x);
	/*for(int i:s[x]) {
		cout<<x<<" : pre : "<<i<<'\n';
	}*/
	for(auto i:vec[x]) {
		if(i.second==1) {
			ins(x,i.first);
			if(x==4) {
			//	cout<<i.first<<" ";
			}
		}
		else del(x,i.first),del(x,i.first);
	}
	if(x==4) {
	/*	cout<<'\n';
		for(int i:s[x]) cout<<i<<" ";*/
	}
	if(s[x].size()>=2) {
		auto p=s[x].end(); --p; 
	//	if(x==4) cout<<dep[LCA(*s[x].begin(),*p)]<<":\n";
		dans[x]-=dep[LCA(*s[x].begin(),*p)];
	}
	/*if(s[x].size()>=2) {
		ans[x]-=dist(*s[x].begin(),*s[x].end());
	}*/
/*	tmp.clear();
	for(int i:s[x]) {
		tmp.pb(i);
	}
	sort(tmp.begin(),tmp.end(),cmp);
	for(int i=0;i<tmp.size();i++) ans+=dep[tmp[i]];
	for(int i=1;i<tmp.size();i++) ans-=dep[LCA(tmp[i-1],tmp[i])];
	int d=0;
	for(int i=1;i<tmp.size();i++) {
		if(!d) d=LCA(tmp[i-1],tmp[i]); else d=LCA(d,tmp[i]);
	}
	if(d) ans-=dep[d];*/
}

signed main() {
//	freopen("language5.in","r",stdin);
	cin.tie(0);
	std::ios::sync_with_stdio(false);
	cin>>n>>m;
	for(int i=1;i<n;i++) {
		int x,y; cin>>x>>y;
		g[x].pb(y); g[y].pb(x);
	}
	SP::init();
	while(m--) {
		int x,y; cin>>x>>y;
		modify(x,y);
	}
	dfs(1,0); int res=0;
	for(int i=1;i<=n;i++) res+=ans[i]+dans[i];
	//for(int i=1;i<=n;i++) cout<<ans[i]<<' '; 
	cout<<res/2;
//	cout<<ans[4];
	return 0;
}
上一篇:每日一道算法(统计字符)


下一篇:(3)tomcat源代码分析环境的搭建