\(\text{Problem}:\)PolandBall and Many Other Balls
\(\text{Preface}:\)
这是一道有着多种不同做法的经典好题。由于笔者水平有限,仅在下文介绍三种主流做法。
\(\text{Solution 1}:\) 组合容斥
设 \(f_{k}\) 表示取出 \(k\) 组的方案数。考虑其组合意义,从 \(k\) 组中选 \(i\) 组包含两个球,将包含两个球的位置绑定,就只剩 \(n-i\) 个位置放 \(k\) 个球。故有:
\[f_{k}=\sum\limits_{i=0}^{k}\binom{k}{i}\binom{n-i}{k} \]发现这个式子难以展开推导,故考虑改变其组合意义。发现 \(\binom{k}{i}\binom{n-i}{k}\) 等价于一排有 \(n\) 个球,在前 \(k\) 个中选若干个,在剩余的球中任意选 \(k\) 个的方案数。根据定义,第一次选择的球和第二次选择的 \(k\) 个球是没有交集的。这提示我们容斥求解。
设 \(g_{i}\) 表示恰好有 \(i\) 个球被重复选到的方案数,\(h_{i}\) 表示钦定有 \(i\) 个球被重复选到的方案数。考虑求出 \(h_{i}\)。钦定前 \(k\) 个中有 \(i\) 个被重复选择,剩下 \(k-i\) 个没有选择限制,而 \(n-i\) 个球中只能选出 \(k-i\) 个。故有:
\[\begin{aligned} g_{i}&=\sum\limits_{j=i}^{k}(-1)^{j-i}\binom{j}{i}h_{j}\\ &=\sum\limits_{j=i}^{k}(-1)^{j-i}\binom{j}{i}\binom{k}{j}2^{k-j}\binom{n-j}{k-j} \end{aligned} \]显然 \(f_{k}=g_{0}\),得到:
\[\begin{aligned} f_{k}&=\sum\limits_{j=0}^{k}(-1)^{j}\binom{k}{j}2^{k-j}\binom{n-j}{k-j}\\ &=\frac{k!}{(n-k)!}\sum\limits_{j=0}^{k}\frac{(-1)^{j}(n-j)!}{j!}\cdot\frac{2^{k-j}}{(k-j)!(k-j)!}\\ &=k!\cdot n^{\underline{k}}\sum\limits_{j=0}^{k}\frac{(-1)^{j}}{j!\cdot n^{\underline{j}}}\cdot\frac{2^{k-j}}{(k-j)!(k-j)!} \end{aligned} \]注意到 \(n^{\underline{j}}\) 中如果存在 \(998244353\),那么 \(n^{\underline{k}}\) 中一定存在 \(998244353\)。故以 \(n^{\underline{k}}\) 在模 \(998244353\) 意义下是否为 \(0\) 分段,做两次 \(\text{NTT}\) 即可。时间复杂度 \(O(n\log n)\)。
\(\text{Code 1}:\)
#include <bits/stdc++.h>
#pragma GCC optimize(3)
//#define int long long
#define ri register
#define mk make_pair
#define fi first
#define se second
#define pb push_back
#define eb emplace_back
#define is insert
#define es erase
#define vi vector<int>
#define vpi vector<pair<int,int>>
using namespace std; const int N=135010, Mod=998244353;
inline int read()
{
int s=0, w=1; ri char ch=getchar();
while(ch<'0'||ch>'9') { if(ch=='-') w=-1; ch=getchar(); }
while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+(ch^48), ch=getchar();
return s*w;
}
int n,K;
int rev[N],r[24][2],fac[N+5],inv[N+5],pw[N+5],ifac[N+5],tfac[N+5];
inline int ksc(int x,int p) { int res=1; for(;p;p>>=1, x=1ll*x*x%Mod) if(p&1) res=1ll*res*x%Mod; return res; }
inline void Get_Rev(int T) { for(ri int i=0;i<T;i++) rev[i]=(rev[i>>1]>>1)|((i&1)?(T>>1):0); }
inline void DFT(int T,vector<int> &s,int type)
{
for(ri int i=0;i<T;i++) if(rev[i]<i) swap(s[i],s[rev[i]]);
for(ri int i=2,cnt=1;i<=T;i<<=1,cnt++)
{
int wn=r[cnt][type];
for(ri int j=0,mid=(i>>1);j<T;j+=i)
{
for(ri int k=0,w=1;k<mid;k++,w=1ll*w*wn%Mod)
{
int x=s[j+k], y=1ll*w*s[j+mid+k]%Mod;
s[j+k]=(x+y)%Mod;
s[j+mid+k]=x-y;
if(s[j+mid+k]<0) s[j+mid+k]+=Mod;
}
}
}
if(!type) for(ri int i=0,inv=ksc(T,Mod-2);i<T;i++) s[i]=1ll*s[i]*inv%Mod;
}
inline void NTT(int n,int m,vector<int> &A,vector<int> B)
{
int len=n+m;
int T=1;
while(T<=len) T<<=1;
Get_Rev(T);
A.resize(T), B.resize(T);
for(ri int i=n+1;i<T;i++) A[i]=0;
for(ri int i=m+1;i<T;i++) B[i]=0;
DFT(T,A,1), DFT(T,B,1);
for(ri int i=0;i<T;i++) A[i]=1ll*A[i]*B[i]%Mod;
DFT(T,A,0);
}
signed main()
{
fac[0]=1;
for(ri int i=1;i<=N;i++) fac[i]=1ll*fac[i-1]*i%Mod;
inv[N]=ksc(fac[N],Mod-2);
for(ri int i=N;i;i--) inv[i-1]=1ll*inv[i]*i%Mod;
pw[0]=1;
for(ri int i=1;i<=N;i++) pw[i]=(pw[i-1]<<1)%Mod;
r[23][1]=ksc(3,119), r[23][0]=ksc(ksc(3,Mod-2),119);
for(ri int i=22;~i;i--) r[i][0]=1ll*r[i+1][0]*r[i+1][0]%Mod, r[i][1]=1ll*r[i+1][1]*r[i+1][1]%Mod;
n=read(), K=read();
ifac[0]=tfac[0]=1;
for(ri int i=1;i<=n&&i<=N;i++)
{
ifac[i]=1ll*ifac[i-1]*((n-i+1)%Mod)%Mod;
tfac[i]=tfac[i-1];
if((n-i+1)%Mod) tfac[i]=1ll*tfac[i]*((n-i+1)%Mod)%Mod;
}
vector<int> A,B;
int ct=K+1;
for(ri int i=1;i<=K;i++) if(!ifac[i]) { ct=i; break; }
A.resize(ct), B.resize(ct);
for(ri int i=0;i<ct;i++)
{
if(i&1) A[i]=Mod-1ll*inv[i]*ksc(ifac[i],Mod-2)%Mod;
else A[i]=1ll*inv[i]*ksc(ifac[i],Mod-2)%Mod;
B[i]=1ll*pw[i]*inv[i]%Mod*inv[i]%Mod;
}
NTT(ct,ct,A,B);
for(ri int i=1;i<ct;i++)
printf("%d ",1ll*A[i]*ifac[i]%Mod*fac[i]%Mod);
vector<int>().swap(A);
A.resize(K+1), B.resize(K+1);
for(ri int i=ct;i<=K;i++)
{
if(i&1) A[i]=Mod-1ll*inv[i]*ksc(tfac[i],Mod-2)%Mod;
else A[i]=1ll*inv[i]*ksc(tfac[i],Mod-2)%Mod;
A[i]%=Mod;
B[i]=1ll*pw[i]*inv[i]%Mod*inv[i]%Mod;
}
NTT(K,K,A,B);
for(ri int i=ct;i<=K;i++)
printf("%d ",1ll*A[i]*tfac[i]%Mod*fac[i]%Mod);
puts("");
return 0;
}
\(\text{Solution 2}:\) 特征方程
设 \(f_{i,j}\) 表示前 \(i\) 个球分成 \(j\) 组的方案数,有转移:
\[f_{i,j}=f_{i-1,j}+f_{i-1,j-1}+f_{i-2,j-1} \]设 \(F_{i}(x)\) 表示序列 \(f_{i}\) 的 \(\text{OGF}\),有:
\[F_{i}(x)=(x+1)F_{i-1}(x)+xF_{i-2}(x) \]该递推式的特征方程为:
\[z^{2}-(x+1)z-x=0 \]设两解 \(z_{1},z_{2},z_{1}\geq z_{2}\),由求根公式,有:
\[z_{1}=\frac{x+1+\sqrt{x^{2}+6x+1}}{2},z_{2}=\frac{x+1-\sqrt{x^{2}+6x+1}}{2} \]现在设 \(F_{n}(x)=c_{1}z_{1}^{n}+c_{2}z_{2}^{n}\),由 \(F_{0}(x)=1,F_{1}(x)=1+x\),得到:
\[c_{1}=\frac{z_{1}}{\sqrt{x^{2}+6x+1}},c_{2}=\frac{z_{2}}{\sqrt{x^{2}+6x+1}}\\ F_{n}(x)=\frac{z_{1}^{n+1}-z_{2}^{n+1}}{\sqrt{x^{2}+6x+1}} \]发现 \(z_{2}\) 的常数项为 \(0\),即 \(z_{2}^{n+1}\equiv 0\pmod {x^{n+1}}\),有:
\[F_{n}(x)=\frac{z_{1}^{n+1}}{\sqrt{x^{2}+6x+1}}=\frac{(x+1+\sqrt{x^{2}+6x+1})^{n+1}}{2^{n+1}\sqrt{x^{2}+6x+1}} \]利用多项式快速幂即可在 \(O(n\log n)\) 的时间复杂度内解决本题。
\(\text{Code 2}:\)
#include <bits/stdc++.h>
#pragma GCC optimize(3)
//#define int long long
#define ri register
#define mk make_pair
#define fi first
#define se second
#define pb push_back
#define eb emplace_back
#define is insert
#define es erase
#define vi vector<int>
#define vpi vector<pair<int,int>>
using namespace std; const int N=135010, Mod=998244353;
inline int read()
{
int s=0, w=1; ri char ch=getchar();
while(ch<'0'||ch>'9') { if(ch=='-') w=-1; ch=getchar(); }
while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+(ch^48), ch=getchar();
return s*w;
}
int n,K;
int rev[N],r[24][2],iiv[N+5];
inline int ksc(int x,int p) { int res=1; for(;p;p>>=1, x=1ll*x*x%Mod) if(p&1) res=1ll*res*x%Mod; return res; }
inline void Get_Rev(int T) { for(ri int i=0;i<T;i++) rev[i]=(rev[i>>1]>>1)|((i&1)?(T>>1):0); }
inline void DFT(int T,vector<int> &s,int type)
{
for(ri int i=0;i<T;i++) if(rev[i]<i) swap(s[i],s[rev[i]]);
for(ri int i=2,cnt=1;i<=T;i<<=1,cnt++)
{
int wn=r[cnt][type];
for(ri int j=0,mid=(i>>1);j<T;j+=i)
{
for(ri int k=0,w=1;k<mid;k++,w=1ll*w*wn%Mod)
{
int x=s[j+k], y=1ll*w*s[j+mid+k]%Mod;
s[j+k]=(x+y)%Mod;
s[j+mid+k]=x-y;
if(s[j+mid+k]<0) s[j+mid+k]+=Mod;
}
}
}
if(!type) for(ri int i=0,inv=ksc(T,Mod-2);i<T;i++) s[i]=1ll*s[i]*inv%Mod;
}
inline void NTT(int n,int m,vector<int> &A,vector<int> B)
{
int len=n+m;
int T=1;
while(T<=len) T<<=1;
Get_Rev(T);
A.resize(T), B.resize(T);
DFT(T,A,1), DFT(T,B,1);
for(ri int i=0;i<T;i++) A[i]=1ll*A[i]*B[i]%Mod;
DFT(T,A,0);
}
void GetInv(int n,vector<int> &F,vector<int> G)
{
if(n==1) { F[0]=ksc(G[0],Mod-2); return; }
GetInv((n+1)/2,F,G);
vector<int> A,B;
int T=1;
while(T<=n+n) T<<=1;
Get_Rev(T);
A.resize(T), B.resize(T);
for(ri int i=0;i<n;i++) A[i]=F[i], B[i]=G[i];
DFT(T,A,1), DFT(T,B,1);
for(ri int i=0;i<T;i++) A[i]=(2ll*A[i]%Mod-1ll*B[i]*A[i]%Mod*A[i]%Mod+Mod)%Mod;
DFT(T,A,0);
for(ri int i=0;i<n;i++) F[i]=A[i];
}
void GetDao(int n,vector<int> &A,vector<int> B)
{
for(ri int i=0;i<n-1;i++) A[i]=1ll*(i+1)*B[i+1]%Mod;
A[n-1]=0;
}
void GetJi(int n,vector<int> &A,vector<int> B)
{
for(ri int i=1;i<n;i++) A[i]=1ll*B[i-1]*iiv[i]%Mod;
A[0]=0;
}
void GetLn(int n,vector<int> &F,vector<int> G)
{
vector<int> A,B;
A.resize(n), B.resize(n);
GetDao(n,A,G);
GetInv(n,B,G);
NTT(n,n,A,B);
GetJi(n,F,A);
}
void GetExp(int n,vector<int> &F,vector<int> G)
{
if(n==1) { F[0]=1; return; }
GetExp((n+1)/2,F,G);
vector<int> C;
C.resize(n);
GetLn(n,C,F);
vector<int> A,B;
int T=1;
while(T<=n+n) T<<=1;
Get_Rev(T);
A.resize(T), B.resize(T);
for(ri int i=0;i<n;i++) A[i]=F[i], B[i]=(G[i]-C[i]+Mod)%Mod; B[0]++;
DFT(T,A,1), DFT(T,B,1);
for(ri int i=0;i<T;i++) A[i]=1ll*A[i]*B[i]%Mod;
DFT(T,A,0);
for(ri int i=0;i<n;i++) F[i]=A[i];
}
struct Node { int x,y; }; int I2;
inline Node operator * (Node a,Node b)
{
int w1=(1ll*a.x*b.x%Mod+1ll*a.y*b.y%Mod*I2%Mod)%Mod;
int w2=(1ll*a.x*b.y%Mod+1ll*a.y*b.x%Mod)%Mod;
return (Node){w1,w2};
}
inline Node KSC(Node x,int p) { Node res=(Node){1,0}; for(;p;p>>=1, x=x*x) if(p&1) res=res*x; return res; }
inline bool Check(int x) { return ksc(x,(Mod-1)/2)==1; }
inline int Rand()
{
int w=1ll*rand()*rand()%Mod;
w+=rand()-rand();
w=(w%Mod+Mod)%Mod;
return w;
}
inline int Cipolla(int n)
{
if(!n) return 0;
int a=0;
while(Check((1ll*a*a%Mod-n+Mod)%Mod)) a=Rand();
I2=(1ll*a*a%Mod-n+Mod)%Mod;
int X1=KSC((Node){a,1},(Mod+1)/2).x;
return min(X1,Mod-X1);
}
void Getsqrt(int n,vector<int> &F,vector<int> G)
{
if(n==1) { F[0]=Cipolla(G[0]); return; }
Getsqrt((n+1)/2,F,G);
vector<int> A,B;
A.resize(n), B.resize(n);
GetInv(n,A,F);
for(ri int i=0;i<n;i++) B[i]=G[i];
NTT(n,n,A,B);
for(ri int i=0,inv2=(Mod+1)/2;i<n;i++) F[i]=1ll*(F[i]+A[i])*inv2%Mod;
}
signed main()
{
srand(time(NULL));
iiv[1]=1;
for(ri int i=2;i<=N;i++) iiv[i]=1ll*(Mod-Mod/i)*iiv[Mod%i]%Mod;
r[23][1]=ksc(3,119), r[23][0]=ksc(ksc(3,Mod-2),119);
for(ri int i=22;~i;i--) r[i][0]=1ll*r[i+1][0]*r[i+1][0]%Mod, r[i][1]=1ll*r[i+1][1]*r[i+1][1]%Mod;
n=read(), K=read(); K++;
vector<int> a,F;
a.resize(K), F.resize(K);
a[0]=1, a[1]=6;
if(K>2) a[2]=1;
Getsqrt(K,a,a);
vector<int> G; G=a;
vector<int> H=G;
for(ri int i=0;i<K;i++) G[i]=0;
GetInv(K,G,H);
a[0]++, a[0]%=Mod, a[1]++, a[1]%=Mod;
int inv2=(Mod+1)/2;
int w=1ll*a[0]*inv2%Mod;
for(ri int i=0,inv=ksc(w,Mod-2);i<K;i++) a[i]=1ll*a[i]*inv%Mod*inv2%Mod;
GetLn(K,F,a);
for(ri int i=0;i<K;i++) F[i]=1ll*F[i]*((n+1)%Mod)%Mod;
GetExp(K,F,F);
for(ri int i=0,gg=ksc(w,(n+1)%(Mod-1));i<K;i++) F[i]=1ll*F[i]*gg%Mod;
NTT(K,K,F,G);
for(ri int i=1;i<K;i++) printf("%d ",(i<=n)?F[i]:0);
puts("");
return 0;
}
\(\text{Solution 3}:\) 倍增
设 \(f_{i,j}\) 表示前 \(i\) 个球分成 \(j\) 组的方案数。与 \(\text{Solution 2}\) 不同的,我们考虑另外一种转移。每次转移可以看作两排球并在一起,根据并的位置是否将一组包含两个球划开分类讨论,有转移:
\[f_{x+y,i}=\sum\limits_{j=0}^{i}f_{x,j}\times f_{y,i-j}+\sum\limits_{j=0}^{i-1}f_{x-1,j}\times f_{y-1,i-j-1} \]设 \(F_{i}(x)\) 表示序列 \(f_{i}\) 的 \(\text{OGF}\),有:
\[F_{i+j}(x)=F_{i}(x)F_{j}(x)+xF_{i-1}(x)F_{j-1}(x) \]发现 \(i+j\) 的转移中还需要求出 \(i-1,j-1\),而 \(i+j-1\) 的转移中还需要求出 \(i-2,j-2\)(但求 \(i+j-2\) 也只需 \(i-2,j-2\)),故维护 \(F_{2^{k}},F_{2^{k}-1},F_{2^{k}-2}\) 倍增求解即可。总时间复杂度 \(O(n\log^{2}n)\)。
\(\text{Code 3}:\)
#include <bits/stdc++.h>
#pragma GCC optimize(3)
//#define int long long
#define ri register
#define mk make_pair
#define fi first
#define se second
#define pb push_back
#define eb emplace_back
#define is insert
#define es erase
#define vi vector<int>
#define vpi vector<pair<int,int>>
using namespace std; const int N=135010, Mod=998244353;
inline int read()
{
int s=0, w=1; ri char ch=getchar();
while(ch<'0'||ch>'9') { if(ch=='-') w=-1; ch=getchar(); }
while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+(ch^48), ch=getchar();
return s*w;
}
int n,K;
struct DP
{
vector<int> a[3];
}f,g;
int rev[N],r[24][2];
inline int ksc(int x,int p) { int res=1; for(;p;p>>=1, x=1ll*x*x%Mod) if(p&1) res=1ll*res*x%Mod; return res; }
inline void Get_Rev(int T) { for(ri int i=0;i<T;i++) rev[i]=(rev[i>>1]>>1)|((i&1)?(T>>1):0); }
inline void DFT(int T,vector<int> &s,int type)
{
for(ri int i=0;i<T;i++) if(rev[i]<i) swap(s[i],s[rev[i]]);
for(ri int i=2,cnt=1;i<=T;i<<=1,cnt++)
{
int wn=r[cnt][type];
for(ri int j=0,mid=(i>>1);j<T;j+=i)
{
for(ri int k=0,w=1;k<mid;k++,w=1ll*w*wn%Mod)
{
int x=s[j+k], y=1ll*w*s[j+mid+k]%Mod;
s[j+k]=(x+y)%Mod;
s[j+mid+k]=x-y;
if(s[j+mid+k]<0) s[j+mid+k]+=Mod;
}
}
}
if(!type) for(ri int i=0,inv=ksc(T,Mod-2);i<T;i++) s[i]=1ll*s[i]*inv%Mod;
}
inline void Merge(DP &x,DP y)
{
int len=(int)x.a[0].size()+(int)y.a[0].size()-1;
int lim=len;
len=min(len,K);
lim=min(lim,K+K);
vector<int> w[6];
for(ri int i=0;i<3;i++) w[i]=x.a[i];
for(ri int i=3;i<6;i++) w[i]=y.a[i-3];
int T=1;
while(T<=lim) T<<=1;
Get_Rev(T);
for(ri int i=0;i<6;i++) w[i].resize(T), DFT(T,w[i],1);
for(ri int i=0;i<T;i++)
{
int h0=w[0][i];
int h1=w[1][i];
int h2=w[2][i];
int h3=w[3][i];
int h4=w[4][i];
int h5=w[5][i];
w[0][i]=1ll*h0*h3%Mod;
w[1][i]=1ll*h0*h4%Mod;
w[2][i]=1ll*h1*h4%Mod;
w[3][i]=1ll*h1*h5%Mod;
w[4][i]=1ll*h2*h5%Mod;
}
for(ri int i=0;i<5;i++) DFT(T,w[i],0);
for(ri int i=0;i<3;i++) x.a[i].resize(len+1);
for(ri int i=0;i<=len;i++)
{
x.a[0][i]=w[0][i];
x.a[1][i]=w[1][i];
x.a[2][i]=w[2][i];
if(i) x.a[0][i]=(x.a[0][i]+w[2][i-1])%Mod, x.a[1][i]=(x.a[1][i]+w[3][i-1])%Mod, x.a[2][i]=(x.a[2][i]+w[4][i-1])%Mod;
}
}
signed main()
{
r[23][1]=ksc(3,119), r[23][0]=ksc(ksc(3,Mod-2),119);
for(ri int i=22;~i;i--) r[i][0]=1ll*r[i+1][0]*r[i+1][0]%Mod, r[i][1]=1ll*r[i+1][1]*r[i+1][1]%Mod;
n=read(), K=read();
for(ri int i=0;i<3;i++) f.a[i].resize(1), g.a[i].resize(2);
f.a[0][0]=1;
g.a[0][0]=g.a[1][0]=g.a[0][1]=1;
for(int p=n;p;p>>=1, Merge(g,g)) if(p&1) Merge(f,g);
for(ri int i=1;i<=K;i++) printf("%d ",(i<=n)?f.a[0][i]:0);
puts("");
return 0;
}
后记:三种做法在 \(\text{CF}\) 上的用时分别为 \(124,1060,3541\)(单位 \(\text{ms}\)),可见实际运行效率还是有一定的差距。