NTT学习笔记

NTT就是模p意义下的FFT

前置知识

1.原根:
\[ g^{p-1}\equiv1\ \ (mod\ p)\\ g^0,g^1,g^2...g^{p-1}互不相同\\ 那么当p=k*2^n+1时\\ 设g_n=g^k\\ 则有\\ g_n^n=1\ \ (mod\ p)\\ g_n^{\frac{n}{2}}=-1\ \ (mod\ p)\\ \therefore g在模p的意义下与\omega等价\\ \]
2.中国剩余定理合并
\[ 给出\\ x\equiv a_1\ \ (mod\ p_1)\\ x\equiv a_2\ \ (mod\ p_2)\\ ...\\ x\equiv a_n\ \ (mod\ p_n)\\ 求一个特解\\ 设M=\Pi_{i=1}^np_i\\ 设M_i=\frac{M}{p_i}\\ 设t_i满足M_it_i\equiv1\ \ (mod\ p_i)\\ \therefore t_i\equiv M_i^{-1}\ \ (mod\ p_i)\\ 则有特解x_0=\Sigma(a_it_iM_i)\ mod\ M \]

NTT

如果\(p=k*2^n+1\)就可以直接搞了

但是如果\(p\ne k*2^n+1​\)呢

于是我们可以用三模数求解然后用中国剩余定理合并

洛谷P4245 【模板】任意模数NTT

#include<bits/stdc++.h>
#define N 400005
#define ll long long
#define pi 3.1415926535
using namespace std;
int Pow(int x,int y,int p){
    int re=1;
    while(y){
        if(y&1)re=1ll*re*x%p;
        x=1ll*x*x%p;
        y>>=1;
    }
    return re%p;
}
ll Mul(ll a,ll b,ll mod){
    a%=mod,b%=mod;
    return ((a*b-(ll)((ll)((long double)a/mod*b+1e-3)*mod))%mod+mod)%mod;
}
int rota[N],cnt;
void pre(int n,int m){
    int high=0;
    cnt=1;
    while(cnt<=n+m)cnt<<=1,high++;
    for(int i=0;i<cnt;i++)
        rota[i]=(rota[i>>1]>>1)|((i&1)<<(high-1));
}
void ntt(int lim,int *buf,int dft,int mod){
    for(int i=0;i<lim;i++)if(i<rota[i])swap(buf[i],buf[rota[i]]);
    for(int len=2;len<=lim;len<<=1){
        int g1=Pow(3,(mod-1)/len,mod);
        if(dft==-1)g1=Pow(g1,mod-2,mod);
        for(int s=0;s<lim;s+=len){
            int w=1;
            for(int k=s;k<s+len/2;k++,w=1ll*w*g1%mod){
                int x=buf[k],y=1ll*w*buf[k+len/2]%mod;
                buf[k]=x+y;
                buf[k+len/2]=x-y;
                buf[k]=(buf[k]%mod+mod)%mod;
                buf[k+len/2]=(buf[k+len/2]%mod+mod)%mod;
            }
        }
    }
    if(dft==-1){
        int inv=Pow(lim,mod-2,mod);
        for(int i=0;i<lim;i++)buf[i]=1ll*inv*buf[i]%mod;
    }
}
int n,m,mod;
int a[3][N],b[3][N],c[3][N],p[3]={469762049,998244353,1004535809};
signed main(){
    scanf("%d%d%d",&n,&m,&mod);
    for(int i=0,x;i<=n;i++){
        scanf("%d",&x);
        a[0][i]=a[1][i]=a[2][i]=x;
        for(int j=0;j<3;j++)a[j][i]%=p[j];
    }
    for(int i=0,x;i<=m;i++){
        scanf("%d",&x);
        b[0][i]=b[1][i]=b[2][i]=x;
        for(int j=0;j<3;j++)b[j][i]%=p[j];
    }
    pre(n,m);
    for(int i=0;i<3;i++){
        ntt(cnt,a[i],1,p[i]);
        ntt(cnt,b[i],1,p[i]);
        for(int j=0;j<=cnt;j++)
            c[i][j]=1ll*a[i][j]*b[i][j]%p[i];
        ntt(cnt,c[i],-1,p[i]);
    }
    ll M=1ll*p[0]*p[1];
    for(int i=0;i<=n+m;i++){
        ll A=Mul(1ll*c[0][i]*p[1]%M,Pow(p[1]%p[0],p[0]-2,p[0]),M);
        A=(A+Mul(1ll*c[1][i]*p[0]%M,Pow(p[0]%p[1],p[1]-2,p[1]),M))%M;
        ll K=((c[2][i]-A)%p[2]+p[2])%p[2]*Pow(M%p[2],p[2]-2,p[2])%p[2];
        printf("%lld ",((K%mod)*(M%mod)%mod+A%mod)%mod);
    }
    return 0;
}
/*
5 8 28
19 32 0 182 99 95
77 54 15 3 98 66 21 20 38

*/
上一篇:[bzoj3992][SDOI2015]序列统计——离散对数+NTT


下一篇:C++11的value category(值类别)以及move semantics(移动语义)