第一和第二类斯特林数的学习

最近在学第一类和第二类斯特林数。这里记录一下学习的知识点/模板还有题目。

https://blog.csdn.net/litble/article/details/80882581

https://www.cnblogs.com/y2823774827y/p/10700231.html

https://www.cnblogs.com/cjyyb/p/10142878.html

第一类斯特林数是解决:将n个元素划分为k个圆排列的方案数,递推式为f(i,j)=f(i−1,j−1)+(i−1)f(i−1,j),用递推式求某个S1(n,k)的话时间是O(n^2)不太理想。

一种比较好的做法是根据第一类斯特林数的性质:x*(x+1)(x+2)(x+3)……(x+n-1)=Σf[n][i]*x^i ,可以用分治NTT求其生成函数,然后第i项的系数即是S1[n][i]。这样的时间是O(n*log2n^2)

其实还有O(n*log2n)的求法,我还没学。

例题+模板:codeforces 960G Bandit Blues

推式子后发现 ans=S1[n-1][A+B-2]*C(A+B-2,A-1) 。于是主要矛盾就是求S1(n-1,A+B-2)

#include<bits/stdc++.h>
#define LL long long
using namespace std;
const LL P=998244353,yg=3;
LL A[400010];
LL bin[400010];

LL power(LL x,LL p) {
    LL ret=1;
    for (;p;p>>=1) {
        if (p&1) ret=(ret*x)%P;
        x=(x*x)%P;
    }
    return ret;
}

void NTT(LL *a,LL n,LL op) {  //NTT:系数a数组,长度为n,op=1求值op=-1插值 
    for(LL i=0;i<n;i++) bin[i]=(bin[i>>1]>>1)|((i&1)*(n>>1));
    for(LL i=0;i<n;i++) if(i<bin[i]) swap(a[i],a[bin[i]]);
    for(LL i=1;i<n;i<<=1) {
        LL wn=power(yg,op==1?(P-1)/(2*i):(P-1)-(P-1)/(2*i)),w,t;
        for(LL j=0;j<n;j+=i<<1) {
            w=1;
            for(LL k=0;k<i;k++) {
                t=a[i+j+k]*w%P;w=w*wn%P;
                a[i+j+k]=(a[j+k]-t+P)%P;a[j+k]=(a[j+k]+t)%P;
            }
        }
    }
    if(op==-1) {
        LL Inv=power(n,P-2);
        for(LL i=0;i<n;i++) a[i]=a[i]*Inv%P;
    }
}

LL n,a,b;
void solve(LL *a,LL len,LL l,LL r) {  //分治NTT求第一类斯特林数 
    if(l==r) {a[0]=l;a[1]=1;return;}  //分治边界 
    LL mid=(l+r)/2; LL a1[len],a2[len];
    memset(a1,0,sizeof(a1));memset(a2,0,sizeof(a2));
    solve(a1,len>>1,l,mid);solve(a2,len>>1,mid+1,r);  //分治,先求两边 
    NTT(a1,len,1);NTT(a2,len,1);
    for(LL i=0;i<len;i++) a[i]=a1[i]*a2[i]%P;  //两边NTT结果相乘得到[l,r]的结果 
    NTT(a,len,-1);
}

LL C(LL m,LL n) {  //求组合数C(m,n) 
    LL fac1=1,fac2=1;
    for(LL i=1;i<=n;i++) (fac1*=i)%=P,(fac2*=(m-i+1))%=P;
    return fac2*power(fac1,P-2)%P;
}

int main()
{
    scanf("%lld%lld%lld",&n,&a,&b);
    if(a+b-2>n-1||!a||!b) return puts("0"),0;
    if(n==1) return puts("1"),0;
    
    LL N=n-1,M=a+b-2;
    LL len=1;while(len<(n+1)<<1) len<<=1;
    solve(A,len,0,N-1);  //求S1(n-1,i)这一行的值 
    
    printf("%lld",A[M]*C(a+b-2,a-1)%P);
}

 

上一篇:【模板】分治FFT


下一篇:CS Academy Round 75 Permutations NTT