【模板】任意模数NTT

题目描述:

luogu

题解:

用$fft$水过(什么$ntt$我不知道)。

众所周知,$fft$精度低,$ntt$处理范围小。

所以就有了任意模数ntt神奇$fft$!

意思是这样的。比如我要算$F*G$,我可以把这两个多项式各分成两个多项式,一个表示$F_x/M$,一个表示$F_x$%$M$($M$是自己设定的阈值)。

比如说$F=a*M+b,G=c*M+d$,那么$F*G=(a*M+b)*(c*M+d)=a*c*M^2+a*d*M+b*c*M+b*d$。

然后?就水过了啊……

顺便提一下,要开$long double$。

代码:

#include<cmath>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N = ;
const long double Pi = acos(-1.0);
template<typename T>
inline void read(T&x)
{
T f = ,c = ;char ch=getchar();
while(ch<''||ch>''){if(ch=='-')f=-;ch=getchar();}
while(ch>=''&&ch<=''){c=c*+ch-'';ch=getchar();}
x = f*c;
}
int n,m,MOD;
struct cp
{
long double x,y;
cp(){}
cp(long double x,long double y):x(x),y(y){}
cp operator + (const cp&a)const{return cp(x+a.x,y+a.y);}
cp operator - (const cp&a)const{return cp(x-a.x,y-a.y);}
cp operator * (const cp&a)const{return cp(x*a.x-y*a.y,x*a.y+y*a.x);}
};
int to[N],lim,L;
void init()
{
lim = ,L = ;
while(lim<=*max(n,m))lim<<=,L++;
for(int i=;i<lim;i++)
to[i] = ((to[i>>]>>)|((i&)<<(L-)));
}
ll A[N],B[N],C[N];
void fft(cp*a,int len,int k)
{
for(int i=;i<len;i++)
if(i<to[i])swap(a[i],a[to[i]]);
for(int i=;i<len;i<<=)
{
cp w0(cos(Pi/i),k*sin(Pi/i));
for(int j=;j<len;j+=(i<<))
{
cp w(,);
for(int o=;o<i;o++,w=w*w0)
{
cp w1 = a[j+o],w2 = a[j+o+i]*w;
a[j+o] = w1+w2;
a[j+o+i] = w1-w2;
}
}
}
if(k==-)
for(int i=;i<len;i++)a[i].x/=len;
}
cp a[N],b[N],c[N],d[N],e[N],f[N],g[N],h[N];
void mtt()
{
int M = ;
for(int i=;i<max(n,m);i++)
{
a[i].x = A[i]/M,b[i].x = A[i]%M;
c[i].x = B[i]/M,d[i].x = B[i]%M;
}
fft(a,lim,),fft(b,lim,),fft(c,lim,),fft(d,lim,);
for(int i=;i<lim;i++)
{
e[i] = a[i]*c[i],f[i] = a[i]*d[i];
g[i] = b[i]*c[i],h[i] = b[i]*d[i];
}
fft(e,lim,-),fft(f,lim,-),fft(g,lim,-),fft(h,lim,-);
for(int i=;i<lim;i++)
C[i] = (((ll)(e[i].x+0.1))%MOD*M%MOD*M%MOD+((ll)(f[i].x+0.1))%MOD*M%MOD
+((ll)(g[i].x+0.1))%MOD*M%MOD+((ll)(h[i].x+0.1))%MOD)%MOD;
}
int main()
{
read(n),read(m),read(MOD);n++,m++;
init();
for(int i=;i<n;i++)read(A[i]);
for(int i=;i<m;i++)read(B[i]);
mtt();
for(int i=;i<n+m-;i++)printf("%lld ",C[i]);
puts("");
return ;
}
上一篇:看起来像一个输入框的input,实际上是有两个input


下一篇:Redis批量导入数据的方法