多项式--MTT

MTT是什么?

看这样一道例题:

P4245 【模板】任意模数多项式乘法 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)

这道题有2个解法,一种就是MTT(基于FFT),一种就是任意模数NTT。

今天主要介绍MTT。

一个多项式:

设A1(X)=A(X)/M.B1(X)=A(X)%M

那么,A(X)=A1(X)*M+B1(X)。

为什么要这么转换,因为如果直接A(X)求的话,值域太大了,会爆掉,所以我们转换一下。

这里的M最好取2^15。

于是

A(X)*B(X)

=(A1(X)*M+B1(X))*(A2(X)*M+B2(X))

=A1(X)*A2(X)*M*M+B1(X)*A2(X)*M+B2(X)*A1(X)*M+B1(X)*B2(X)

因为最后总是要取模,所以这4项可以分开处理。

也就是说,要跑8次FFT,常数有点大,但是这道题足以卡过去。

这里还要提升一下精度为long double,如果还是卡不过去,单位根可以预处理一下(但我没处理卡过去了)。

多项式--MTT
#include<iostream>
#include<cstdio>
#include<cmath>
#define ll long long
using namespace std;
const int maxn=1000010;
int n,m,p,r[maxn],ans[maxn];
int block=32768;
const long double pi=acos(-1.0);
struct complex
{
    long double x,y;
    complex (long double xx=0,long double yy=0)
    {
        x=xx;
        y=yy;
    }
}a1[maxn],b1[maxn],a2[maxn],b2[maxn],x[maxn];
int limit=1,l;
complex operator + (complex a,complex b) {return complex(a.x+b.x,a.y+b.y);}
complex operator - (complex a,complex b) {return complex(a.x-b.x,a.y-b.y);}
complex operator * (complex a,complex b) {return complex(a.x*b.x-a.y*b.y,a.y*b.x+b.y*a.x);}
inline void FFT(complex *f,int type)
{
    for(int i=0;i<limit;i++) 
    if(i<r[i]) swap(f[i],f[r[i]]);
    for(int mid=1;mid<limit;mid<<=1)
    {
        complex Wn (cos(pi/mid),type*sin(pi/mid));
        for(int r=mid<<1,j=0;j<limit;j+=r)
        {
            complex w(1,0);
            for(int k=0;k<mid;k++,w=w*Wn)
            {
                complex x=f[k+j],y=w*f[k+j+mid];
                f[k+j]=x+y;
                f[k+j+mid]=x-y;
            }
        }
    }
}
inline void solve(complex *a,complex *b,int res)
{
    for(int i=0;i<limit;i++) x[i]=a[i]*b[i];
    FFT(x,-1);
    for(int i=0;i<=n+m;i++) (ans[i]+=(ll)(x[i].x/limit+0.5)%p*res%p)%=p;
}
inline void MTT(complex *a,complex *b,complex *c,complex *d)
{
    FFT(a,1);
    FFT(b,1);
    FFT(c,1);
    FFT(d,1);
    solve(a,c,block*block%p);
    solve(a,d,block%p);
    solve(c,b,block%p);
    solve(b,d,1);
}
int main()
{   
    cin>>n>>m>>p;
    for(int i=0;i<=n;i++) 
    {
        int x;
        cin>>x;
        a1[i].x=x/block;
        b1[i].x=x%block;
    }
    for(int i=0;i<=m;i++)
    {
        int x;
        cin>>x;
        a2[i].x=x/block;
        b2[i].x=x%block;
    }
    while(limit<=n+m) limit<<=1,l++;
    for(int i=0;i<limit;i++)
    {
        r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
    }
    MTT(a1,b1,a2,b2);
    for(int i=0;i<=n+m;i++)
    {
        cout<<ans[i]<<' ';
    }
    cout<<endl;
    return 0;
}
View Code
上一篇:6:统计属性


下一篇:Java的基本类型转换