\(\text{Problem}:\)[NOI Online #2 提高组] 游戏
\(\text{Solution}:\)
设 \(f_{k}\) 表示恰好非平局回合数为 \(k\) 的方案数,\(g_{k}\) 表示钦定非平局回合数为 \(k\) 的方案数,有:
\[g_{k}=h_{k}(m-k)!\\ f_{k}=\sum\limits_{i=k}^{m}(-1)^{i-k}\binom{i}{k}g_{i} \]其中 \(h_{k}\) 表示在树上选出点对 \((u,v)\),满足 \(u\) 是 \(v\) 的祖先且 \(u\) 的颜色和 \(v\) 不同的方案数。
设 \(f_{x,i}\) 表示 \(x\) 的子树内选出了 \(i\) 个点对。记 \(y\) 为 \(x\) 的儿子结点,首先与其暴力合并,有:
\[\sum\limits_{j=0}^{i}f_{x,j}\times f_{y,i-j}\rightarrow f_{x,i} \]利用树上背包可以在 \(O(n^{2})\) 的时间复杂度内完成转移。
现在求解新增的答案。考虑 \(x\) 作为白点的情况。设 \(sizb_{x}\) 表示 \(x\) 子树内黑点的个数,有:
\[f_{x,i}\times (sizb_{x}-i)\rightarrow f_{x,i+1} \]对于 \(x\) 为黑点的情况同理。
最后显然有 \(h_{k}=f_{1,k}\),那么可以暴力计算出所有 \(f_{k}\)。总时间复杂度 \(O(n^{2})\)。
\(\text{Code}:\)
#include <bits/stdc++.h>
#pragma GCC optimize(3)
//#define int long long
#define ri register
#define mk make_pair
#define fi first
#define se second
#define pb push_back
#define eb emplace_back
#define is insert
#define es erase
#define vi vector<int>
#define vpi vector<pair<int,int>>
using namespace std; const int N=5010, Mod=998244353;
inline int read()
{
int s=0, w=1; ri char ch=getchar();
while(ch<'0'||ch>'9') { if(ch=='-') w=-1; ch=getchar(); }
while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+(ch^48), ch=getchar();
return s*w;
}
int n,m,a[N],f[N][N],g[N],siz[N][2],fac[N+5],inv[N+5];
inline int ksc(int x,int p) { int res=1; for(;p;p>>=1, x=1ll*x*x%Mod) if(p&1) res=1ll*res*x%Mod; return res; }
inline int C(int x,int y) { if(x<y||x<0||y<0) return 0; return 1ll*fac[x]*inv[x-y]%Mod*inv[y]%Mod; }
int head[N],maxE; struct Edge { int nxt,to; }e[N<<1];
inline void Add(int u,int v) { e[++maxE].nxt=head[u]; head[u]=maxE; e[maxE].to=v; }
void DFS(int x,int fa)
{
siz[x][a[x]]++;
f[x][0]=1;
for(ri int i=head[x];i;i=e[i].nxt)
{
int v=e[i].to;
if(v==fa) continue;
DFS(v,x);
for(ri int j=min(m,siz[x][0]+siz[x][1]);~j;j--)
{
for(ri int k=min(siz[v][0]+siz[v][1],m-j);k;k--)
{
f[x][j+k]=(f[x][j+k]+1ll*f[x][j]*f[v][k]%Mod)%Mod;
}
}
siz[x][0]+=siz[v][0];
siz[x][1]+=siz[v][1];
}
if(a[x])
{
for(ri int i=min(m-1,siz[x][0]-1);~i;i--) f[x][i+1]=(f[x][i+1]+1ll*f[x][i]*(siz[x][0]-i)%Mod)%Mod;
}
else
{
for(ri int i=min(m-1,siz[x][1]-1);~i;i--) f[x][i+1]=(f[x][i+1]+1ll*f[x][i]*(siz[x][1]-i)%Mod)%Mod;
}
}
signed main()
{
fac[0]=1;
for(ri int i=1;i<=N;i++) fac[i]=1ll*fac[i-1]*i%Mod;
inv[N]=ksc(fac[N],Mod-2);
for(ri int i=N;i;i--) inv[i-1]=1ll*inv[i]*i%Mod;
n=read(), m=n/2;
for(ri int i=1;i<=n;i++) scanf("%1d",&a[i]);
for(ri int i=1;i<n;i++)
{
int u,v;
u=read(), v=read();
Add(u,v), Add(v,u);
}
DFS(1,0);
for(ri int i=0;i<=m;i++) g[i]=1ll*f[1][i]*fac[m-i]%Mod;
for(ri int i=0;i<=m;i++)
{
int ans=0;
for(ri int j=i;j<=m;j++)
{
int w=1ll*C(j,i)*g[j]%Mod;
if((j-i)&1) ans=(ans-w+Mod)%Mod;
else ans=(ans+w)%Mod;
}
printf("%d\n",ans);
}
return 0;
}