https://www.luogu.com.cn/problem/P6478
二项式反演
发现恰好有\(k\)个非平局的回合数,“恰好”这并不好做,这里套路地设\(f_i\)表示恰好有\(i\)个非平局的方案数,\(g_i\)表示“钦定”了\(i\)个非平局的然后剩下的\(m-i\)个回合两人随便选点的方案数。
考虑\(f\)和\(g\)有什么关系,对于\(g_k\)与\(f_i\)(\(i \ge k\)),由于是先钦定\(k\)局,考虑\(i\)局中钦定了哪\(k\)局,换句话说\(f_i\)在\(g_k\)中会被重复计数\(\binom{i}{k}\)次
于是
\[g_k=\sum\limits_{i=k}^{m} \binom{i}{k} f_i \]
由二项式反演可得
\[f_k=\sum\limits_{i=k}^{m} (-1)^{i-k} \binom{i}{k} g_i \]
于是算出\(g\)后我们就可以\(O(m^2)\)得到答案了。
考虑树形dp,令\(dp(u,i)\)表示点\(u\)的子树内,选出\(i\)组有祖孙关系的点对的方案数。
只有两种情况
-
- 不选\(u\)节点
-
- 选\(u\)节点
先来看\(1\),不选\(u\)节点的话就类似树上背包那样合并儿子们的信息
这样我们得到了新的\(dp(u,i)\)。
然后如果选择了点\(u\),那么加入这个点对过后会让\(dp(u,i)\)对\(dp(u,i+1)\)做出贡献(这里的\(dp\)值都是上面更新好的),具体的会使得\(dp(u,i+1)\)加上\(dp(u,i)\times u\)子树内的剩下没有被选的点(对手的点)的个数,剩下的点的个数就是\(u\)子树内对手点的个数\(-i\)。
注意上面这个东西不要在更新的时候覆盖\(dp\)值(可以倒序枚举\(i\)或者再开一个数组)
然后显然\(g_k=dp(1,k)\times (m-k)!\)(即钦定\(k\)组的方案数乘上剩下随便选的方案数)
这样背包貌似是\(O(n^3)\)的...
但对于\(dp(u,i)\),粗略的估计一下\(i\)的上界就是\(siz_u\)。
然后按这个上界枚举的话复杂度会到\(O(n^2)\)。
感性的理解就是两个点\(u,v\)仅会在\(lca(u,v)\)处对复杂度造成贡献。
\(Code:\)
#include <bits/stdc++.h>
using namespace std;
#define maxn 5000
#define N 5555
#define MOD 998244353
inline int qpow(int a,int b=MOD-2,int m=MOD)
{
int ans=1%m;
for(;b;b>>=1,a=1ll*a*a%m)
if(b&1) ans=1ll*ans*a%m;
return ans;
}
inline char nc()
{
char ch=getchar(); while(isspace(ch))ch=getchar();
return ch;
}
int fac[N],ifac[N];
inline int comb(int n,int r)
{return 1ll*fac[n]*ifac[r]%MOD*ifac[n-r]%MOD;}
bool bl[N]; int n,m;
int fst[N<<1],nxt[N<<1],to[N<<1],ec;
void ade(int u,int v) {to[++ec]=v,nxt[ec]=fst[u],fst[u]=ec;}
void addedge(int u,int v) {ade(u,v),ade(v,u);}
#define fedge(i,u) for (int i=fst[u],v=to[i];i;i=nxt[i],v=to[i])
int siz[N],c[2][N],f[N][N],df[N];
void dfs(int u,int ff)
{
f[u][0]=1,siz[u]=1,c[0][u]=bl[u]==0,c[1][u]=bl[u]==1;
fedge(_,u) if(v!=ff)
{
dfs(v,u);
for(int i=0;i<=siz[u];i++)
for(int j=0;j<=siz[v];j++)
{
df[i+j]+=1ll*f[u][i]*f[v][j]%MOD;
if(df[i+j]>=MOD) df[i+j]-=MOD;
}
siz[u]+=siz[v];
c[0][u]+=c[0][v],c[1][u]+=c[1][v];
for(int i=0;i<=siz[u];i++) f[u][i]=df[i];
for(int i=0;i<=siz[u];i++) df[i]=0;
}
int cc=c[!bl[u]][u];
for(int i=0;i<=cc;i++)
{
df[i+1]+=1ll*f[u][i]*(cc-i)%MOD;
if(df[i+1]>=MOD) df[i+1]-=MOD;
}
for(int i=0;i<=siz[u];i++) {f[u][i]+=df[i]; if(f[u][i]>=MOD)f[u][i]-=MOD;}
for(int i=0;i<=cc;i++) df[i]=df[i+1]=0;
}
int ans[N];
int main()
{
fac[0]=ifac[0]=1;
for(int i=1;i<=maxn;i++) fac[i]=1ll*fac[i-1]*i%MOD,ifac[i]=qpow(fac[i]);
scanf("%d",&n),m=n>>1;
for(int i=1;i<=n;i++) bl[i]=nc()=='1';
for(int u,v,i=1;i<n;i++)
scanf("%d%d",&u,&v),addedge(u,v);
dfs(1,0);
for(int i=0;i<=m;i++)
{
ans[i]=0;
for(int j=i;j<=m;j++)
{
ans[i]+=1ll*comb(j,i)*(((j-i)&1)?MOD-1:1)%MOD*f[1][j]%MOD*fac[m-j]%MOD;
if(ans[i]>=MOD) ans[i]-=MOD;
}
printf("%d\n",ans[i]);
}
return 0;
}