题目链接
要求“恰好有$k$个回合不平局”,考虑先求“至少有$k$个回合不平局”。
设$f[u][k]$为$u$的子树内,选出$k$对异色祖孙节点的方案数。树上背包走起!
有两种情况,即$u$有没有算进去。先看$u$没被算进去的情况:$$f[u][k]=\sum_{k_1+k_2+k_3+\dots +k_p=k}\prod f[son_1][k_1]\times f[son_2][k_2]\times f[son_3][k_3]\times f[son_p][k_p]$$
不能直接枚举,考虑每次将子树合并:$$f[u][j+k]=\sum_{son}f[son][j]\times f[u][k]$$
这里注意枚举的上界是“已知最大异色点对数量”,可以保证单个子树的枚举时间是$O(n)$的。感性理解:每个点对仅会在$LCA$处贡献一次复杂度。
同时为了避免覆盖的情况,需要额外使用一个一维数组。这里有一个很奇怪的问题,就是使用$memset$清零这个小数组时,找不到正确的长度表达式。见代码注释。
设$sum[u][c]$表示$u$子树内$c$颜色(仅有$0/1$)的点的数量,那么考虑$u$和子树内的异色点配对,有$$f[u][k]+=f[u][k-1]\times(\,sum[u][c\hat{}1]-(j-1)\,)\quad(sum[u][c\hat{}1]>j-1)$$
接下来设$G(k)=f[1][k]\times(m-k)!$,即在选出$k$个异色祖孙点对之后,剩下随意组合的方案总数,我们暂且认为它是至少选出了$k$对的方案总数。
但是这样显然会算重,因为随意组合的点对中也会出现异色祖孙点对。设$H(i)$为恰好选出了$i$个异色祖孙点对的方案总数,可以发现:对于每个$H(i)$,它对$G(k)$的贡献为$\tbinom{i}{k}H(i)$。
举个栗子,假设$i=4,k=3$,有①②③④四个异色祖孙点对。那么有可能是①②③被算在$f[1][k]$内,④是*组合出来的;也有可能是①②④被算在$f[1][k]$内,③是*组合出来的……总共有$\tbinom{i}{k}$种可能。
于是乎$G(k)=\sum_{i=k}^m \tbinom{i}{k}H(i)$,可以进行二项式反演。二项式反演的公式:$$g(k)=\sum_{i=k}^{\infty}\tbinom{i}{k}f(i)\Leftrightarrow f(i)=\sum_{k=i}^{\infty}(-1)^{k-i}\tbinom{k}{i}g(k)$$
代入即可。
$G(k)$并不是“至少选出了$k$对的方案总数”,因为它重复次数太多。笔者认为它起到一个桥梁的作用,方便我们使用二项式反演来求得答案。但不能否认,它内涵的“至少”的概念,和题目中给的“恰好”的概念相对,确实是我们解题的起手式。
程序(100分):
#include<iostream> #include<cstdio> #include<cstring> #include<cmath> #include<algorithm> #include<vector> #include<queue> #include<map> #include<set> #define IL inline #define RG register #define _1 first #define _2 second using namespace std; typedef long long LL; const int N=5000; const LL mod=998244353; int n,m; bool col[N+3]; struct Edge{ int to,nxt; }e[(N<<1)+3]; int h[N+3],top; IL void add(int u,int v){ top++; e[top].to=v; e[top].nxt=h[u]; h[u]=top; } IL LL add(LL x,LL y){return (x+y)%mod;} IL LL mul(LL x,LL y){return x*y%mod;} IL void fadd(LL &x,LL y){x=add(x,y);} IL void fmul(LL &x,LL y){x=mul(x,y);} IL LL qpow(LL a,LL b){ LL ans=1; for(;b;fmul(a,a),b>>=1) if(b&1) fmul(ans,a); return ans; } LL fac[N+3],inv[N+3]; IL void init(){ fac[0]=fac[1]=inv[0]=inv[1]=1; for(int i=2;i<=m;i++) fac[i]=mul(fac[i-1],i); inv[m]=qpow(fac[m],mod-2); for(int i=m;i>2;i--) inv[i-1]=mul(inv[i],i); } IL LL C(int n,int m){ return mul(mul(fac[n],inv[m]),inv[n-m]); } int sz[N+3],sum[N+3][2]; //sz:子树内最大点对数 LL f[N+3][N+3],g[N+3]; void dfs(int u,int fr){ f[u][0]=1; for(int i=h[u];~i;i=e[i].nxt){ int v=e[i].to; if(v==fr) continue; dfs(v,u); sum[u][0]+=sum[v][0]; sum[u][1]+=sum[v][1]; for(int j=0;j<=sz[u]+sz[v];j++) g[j]=0; // memset(g,0,(sz[u]+sz[v])*(sizeof(LL))); for(int j=0;j<=sz[u];j++) for(int k=0;k<=sz[v];k++) fadd(g[j+k], mul(f[u][j],f[v][k]) ); sz[u]+=sz[v]; for(int j=0;j<=sz[u];j++) f[u][j]=g[j]; } sum[u][col[u]]++; for(int i=sz[u];i>=0;i--) if(sum[u][col[u]^1]-i>0) fadd(f[u][i+1], mul(f[u][i],sum[u][col[u]^1]-i) ); if(f[u][sz[u]+1]) sz[u]++; } int main(){ freopen("p2.in","r",stdin); freopen("p2.out","w",stdout); scanf("%d\n",&n); m=n>>1; for(int i=1;i<=n;i++) col[i]=getchar()-'0'; memset(h,-1,sizeof h); top=-1; for(int i=1;i<n;i++){ int u,v; scanf("%d%d",&u,&v); add(u,v); add(v,u); } init(); memset(sz,0,sizeof sz); memset(sum,0,sizeof sum); memset(f,0,sizeof f); dfs(1,0); for(int i=0;i<=sz[1];i++) fmul(f[1][i],fac[m-i]); for(int k=0;k<=m;k++){ LL ans=0,t=1; for(int i=k;i<=sz[1];i++,fmul(t,mod-1)) fadd(ans, mul(t, mul(C(i,k),f[1][i]) ) ); printf("%lld\n",ans); } return 0; }View Code