[bzoj3992][SDOI2015]序列统计——离散对数+NTT

题目大意:

给定一个数字不超过\(m\)的集合\(S\),用\(S\)中的数生成一个长度为\(n\)的序列,求所有序列中的元素乘积模\(m\)等于\(x\)的序列的个数。

思路:

考虑最朴素的\(DP\),设\(f_{i,j}\)为选了\(i\)个数,乘积模\(m\)余\(j\)的方案数,直接转移的时间复杂度是\(O(nm^2)\)的。

不难发现每次转移的过程是相同的,矩阵加速显然不太可行,考虑将乘法形式的转移变成加法形式的转移,这样每次转移即可用NTT优化。

这里需要用到一个叫做离散对数的东西,即在取模的意义下,将每个\(m\)以内的数都表示为\(g^x\)幂的形式,这里的\(g\)为模\(m\)意义下的原根。

这样我们将每个数对\(g\)取对数之后,每次转移便可以用NTT来优化了,但是\(n\)很大还是个问题,这个时候发现多项式乘法也是满足结合律的,既然每次的转移多项式是一样的,直接上快速幂即可。

 
/*=======================================
 * Author : ylsoi
 * Time : 2019.2.4
 * Problem : bzoj3992
 * E-mail : ylsoi@foxmail.com
 * ====================================*/
#include<bits/stdc++.h>
 
#define REP(i,a,b) for(int i=a,i##_end_=b;i<=i##_end_;++i)
#define DREP(i,a,b) for(int i=a,i##_end_=b;i>=i##_end_;--i)
#define debug(x) cout<<#x<<"="<<x<<" "
#define fi first
#define se second
#define mk make_pair
#define pb push_back
typedef long long ll;
 
using namespace std;
 
void File(){
    freopen("bzoj3992.in","r",stdin);
    freopen("bzoj3992.out","w",stdout);
}
 
template<typename T>void read(T &_){
    _=0; T fl=1; char ch=getchar();
    for(;!isdigit(ch);ch=getchar())if(ch=='-')fl=-1;
    for(;isdigit(ch);ch=getchar())_=(_<<1)+(_<<3)+(ch^'0');
    _*=fl;
}
 
const int maxm=8000+10;
const int mod=1004535809;
int n,m,aim,sz,t[maxm];
bool s[maxm];
 
ll qpow(ll x,ll y){
    ll ret=1; x%=mod;
    while(y){
        if(y&1)ret=ret*x%mod;
        x=x*x%mod;
        y>>=1;
    }
    return ret;
}
 
int lim,cnt,dn[maxm<<2];
ll g[maxm<<2],ig[maxm<<2];
 
void ntt(ll *A,int ty){
    REP(i,0,lim-1)if(i<dn[i])swap(A[i],A[dn[i]]);
    for(int len=1;len<lim;len<<=1){
        ll w= ty==1 ? g[len<<1] : ig[len<<1];
        for(int L=0;L<lim;L+=len<<1){
            ll wk=1;
            REP(i,L,L+len-1){
                ll u=A[i],v=A[i+len]*wk%mod;
                A[i]=(u+v)%mod;
                A[i+len]=(u-v)%mod;
                wk=wk*w%mod;
            }
        }
    }
    if(ty==-1){
        ll inv=qpow(lim,mod-2);
        REP(i,0,lim-1)A[i]=A[i]*inv%mod;
        REP(i,m-1,lim-1){
            A[i%(m-1)]=(A[i%(m-1)]+A[i])%mod;
            A[i]=0;
        }
    }
}
 
void init(){
    read(n),read(m),read(aim),read(sz);
    int x;
    REP(i,1,sz)read(x),s[x]=1;
    REP(i,2,m-1){
        x=i;
        ll w=x;
        REP(j,1,m-2){
            if(w==1){
                x=-1;
                break;
            }
            w=w*x%m;
        }
        if(x==i)break;
    }
    for(ll i=1,j=0;j<m-1;i=i*x%m,++j)
        t[i]=j;
    lim=1,cnt=0;
    while(lim<=m+m)lim<<=1,++cnt;
    if(!cnt)cnt=1;
    REP(i,0,lim-1)dn[i]=dn[i>>1]>>1|((i&1)<<(cnt-1));
    g[lim]=qpow(3,(mod-1)/lim);
    ig[lim]=qpow(g[lim],mod-2);
    for(int i=lim>>1;i;i>>=1){
        g[i]=g[i<<1]*g[i<<1]%mod;
        ig[i]=ig[i<<1]*ig[i<<1]%mod;
    }
}
 
ll a[maxm<<2],b[maxm<<2];
 
void work(){
    a[0]=1;
    REP(i,1,m-1)if(s[i])b[t[i]]=1;
    while(n){
        ntt(b,1);
        if(n&1){
            ntt(a,1);
            REP(i,0,lim-1)a[i]=a[i]*b[i]%mod;
            ntt(a,-1);
        }
        REP(i,0,lim-1)b[i]=b[i]*b[i]%mod;
        ntt(b,-1);
        n>>=1;
    }
    printf("%lld\n",(a[t[aim]]+mod)%mod);
}
 
int main(){
    //File();
    init();
    work();
    return 0;
}

上一篇:NTT学习笔记


下一篇:NTT学习笔记