cf1118F2 Tree Cutting (Hard Version)

定一个有 \(n\) 个节点的树, 结点可能有颜色, 共 \(k\) 种颜色, 颜色编号\(1...k\) ,每种颜色都出现。有的点没有颜色, 用 \(0\) 表示. 将其删去 \(k-1\) 条边, 即划分成 \(k\) 个联通块, 使每个联通块中恰好含一种颜色, 颜色为 \(0\) 的节点可以在任意联通块中. 求划分的方案数. ( 无解输出 \(0\) , 答案对 \(998244353\) 取模. )

\(2\leq n\leq 3\cdot 10^5,2\leq k\leq n\)

把每种颜色的点两两之间的边涂色 . 这些边是一定不能被断的 . 并且,这些边相连的点颜色要和当前颜色相同 .

这个时间可以考虑求 \(lca\) ,再对每个点 \(x\) ,到 \(lca\) 路径上的边都要标记,每个点都要染色 .

但是这样是 \(O(n^2+n\log n)\) 的,仔细一想,可以在遇到标记了的点和已经被染色的边的时候停止往上跳动 .

这样,可以做到时间复杂度是 \(O(n+n\log n)\) .

接下来,对于没有标记的边是可以断的 . 但是,随便断,会造成有些联通快中的点都为 \(0\) .

考虑这样一个 \(dp\) .

\(f(i,0/1)\) 表示节点 \(i\) ,目前的联通快全是 \(0\) / 有颜色的断边方法 .

转移的时候需要分类 :

  1. \(a_i=0\)

    \(f(i,0)=\prod f(j,0)\)

    \(f(i,1)=\sum f(j,1)\frac{\prod f(k,0)}{f(j,0)}\)

  2. \(a_i\not=0\)

    \(f(i,1)=\prod f(j,1),if\ a_j=a_i\)

    \(\ \ \ \ \ \ \ \ \ \ \ \ \ \ \ \prod f(j,0), if(a_j\not=a_i)\)

为了防止根节点颜色为 \(0\) ,不好讨论的情况,我强制根节点是一个有颜色的节点.

答案就是 \(f(root,1)\) .

时间复杂度 : \(O(n\log n)\)

空间复杂度 : \(O(n)\)

code
#include<bits/stdc++.h>
using namespace std;
inline int read(){
	char ch=getchar();
	while(ch<'0'||ch>'9')ch=getchar();
	int res=0;
	while(ch>='0'&&ch<='9'){
		res=res*10+ch-'0';
		ch=getchar();
	}
	return res;
}
inline void print(int res){
	if(res==0){
		putchar('0');
		return;
	}
	int a[10],len=0;
	while(res>0){
		a[len++]=res%10;
		res/=10;
	}
	for(int i=len-1;i>=0;i--)
		putchar('0'+a[i]);
}
const int mod=998244353;
int n,k;
int a[300010];
unordered_map<long long,int>Map;
vector<pair<int,int> >g[300010];
vector<int>v[300010];
int dep[300010],anc[300010][20];
bool ok[300010];
int f[300010][2];
void dfs(int x,int fa){
	anc[x][0]=fa;
	for(int i=0;i<(int)g[x].size();i++){
		int to=g[x][i].first,id=g[x][i].second;
		if(to==fa)continue;
		dep[to]=dep[x]+1;
		dfs(to,x);
	}
}
void build(){
	for(int k=0;k+1<20;k++){
		for(int i=0;i<n;i++){
			if(anc[i][k]==-1)anc[i][k+1]=-1;
			else anc[i][k+1]=anc[anc[i][k]][k];
		}
	}
}
int lca(int u,int v){
	if(dep[u]>dep[v])swap(u,v);
	for(int k=0;k<20;k++)if((dep[u]-dep[v])>>k&1)v=anc[v][k];
	if(u==v)return u;
	for(int k=19;k>=0;k--)if(anc[u][k]!=anc[v][k])u=anc[u][k],v=anc[v][k];
	return anc[u][0];
}
inline int ksm(int x,int k){
	if(k==0)return 1;
	int res=ksm(x,k>>1);
	res=1ll*res*res%mod;
	if(k&1)res=1ll*res*x%mod;
	return res;
}
void get(int x,int fa){
	for(int i=0;i<(int)g[x].size();i++){
		int to=g[x][i].first;
		if(to==fa)continue;
		get(to,x);
	}
	if(a[x]==0){
		f[x][0]=1;
		int res=1;
		for(int i=0;i<(int)g[x].size();i++){
			int to=g[x][i].first;
			if(to==fa)continue;
			f[x][0]=1ll*f[x][0]*f[to][0]%mod;
			res=1ll*res*f[to][0]%mod;
		}
		f[x][1]=0;
		for(int i=0;i<(int)g[x].size();i++){
			int to=g[x][i].first;
			if(to==fa)continue;
			int tmp=1ll*res*ksm(f[to][0],mod-2)%mod*f[to][1]%mod;
			f[x][1]=(f[x][1]+tmp)%mod;
		}
		f[x][0]=(f[x][0]+f[x][1])%mod;
	}
	else{
		f[x][0]=0;
		f[x][1]=1;
		for(int i=0;i<(int)g[x].size();i++){
			int to=g[x][i].first;
			if(to==fa)continue;
			if(a[to]==a[x])f[x][1]=1ll*f[x][1]*f[to][1]%mod;
			else f[x][1]=1ll*f[x][1]*f[to][0]%mod;
		}
		if(fa!=-1){
			int id=Map[1ll*x*n+fa];
			if(ok[id]){
				f[x][0]=(f[x][0]+f[x][1])%mod;
			}
		}
	}
}
int main(){
	n=read();k=read();
	for(int i=0;i<n;i++){
		a[i]=read();
		v[a[i]].push_back(i);
	}
	for(int i=0;i<n-1;i++){
		int u=read()-1,v=read()-1;
		g[u].push_back(make_pair(v,i));
		g[v].push_back(make_pair(u,i));
		Map[1ll*u*n+v]=i;
		Map[1ll*v*n+u]=i;
	}
	int root;
	for(int i=0;i<n;i++){
		if(a[i]!=0){
			root=i;
			break;
		}
	}
	dfs(root,-1);
	build();
	memset(ok,true,sizeof(ok));
	for(int i=1;i<=k;i++)if((int)v[i].size()>0){
		int r=v[i][0];
		for(int j=1;j<(int)v[i].size();j++){
			int x=v[i][j];
			r=lca(r,x);
		}
		for(int j=0;j<(int)v[i].size();j++){
			int x=v[i][j];
			while(x!=r){
				int p=anc[x][0];
			//	cout<<x+1<<" "<<p+1<<endl;
				int id=Map[1ll*x*n+p];
				if(!ok[id])break;
				ok[id]=false;
				x=p;
			}
			x=v[i][j];
			while(x!=r){
				int p=anc[x][0];
				if(a[p]==a[x])break;
				if(a[p]!=0&&a[x]!=a[p]){
					cout<<"0\n";
					return 0;
				}
				a[p]=a[x];
				x=p;
			}
		}
	}
	get(root,-1);
	print(f[root][1]);
	putchar('\n');
	return 0;
}
/*inline? ll or int? size? min max?*/
/*
5 2
2 0 0 1 2
1 2
2 3
2 4
2 5
*/
/*
7 3
0 1 0 2 2 3 0
1 3
1 4
1 5
2 7
3 6
4 7
*/
上一篇:Cutting Bamboos 主席树+二分 牛客


下一篇:Hadoop之后:实时数据的未来