多项式乘法模板(FFT)
题目链接
递归实现
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
// Capacity of the global coefficient arrays (4e6 + 10 elements).
const int N = 4000010;
// pi, obtained as arccos(-1) at start-up.
const double pi = acos(-1.0);
using namespace std;
// long long shorthand (not used by the code shown here).
using ll = long long;
// Minimal complex number used by the FFT: x is the real part, y the
// imaginary part. The arithmetic operators are const member functions so they
// can also be applied to const operands (e.g. a const twiddle factor).
struct complex{
    double x,y;
    complex(double xx=0,double yy=0):x(xx),y(yy){}
    complex operator +(const complex &a) const {return complex(x+a.x,y+a.y);}
    complex operator -(const complex &a) const {return complex(x-a.x,y-a.y);}
    // (x + iy)(a.x + i a.y) = (x a.x - y a.y) + i (x a.y + y a.x)
    complex operator *(const complex &a) const {return complex(x*a.x-y*a.y,x*a.y+y*a.x);}
};
// Recursive radix-2 FFT, in place on a[0..n). n is expected to be a power of
// two (main passes such a length).
// type = 1: forward transform, using the twiddle cos(2pi/n) + i*sin(2pi/n).
// type = -1: inverse transform WITHOUT the 1/n scaling; the caller divides by n.
void fft(complex* a,int n,int type){
    if(n<=1) return;                 // also stops n == 0 from recursing forever
    const int half=n>>1;
    // Split into even- and odd-indexed coefficients. The scratch lives on the
    // heap: the former VLAs (non-standard C++) placed ~32*n bytes on the stack
    // per level, which overflows the default stack at n = 2^21. The old split
    // loop also ran to i<=n and read a[n], a[n+1] past the input.
    std::vector<complex> even(half),odd(half);
    for(int i=0;i<half;i++){
        even[i]=a[2*i];
        odd[i]=a[2*i+1];
    }
    fft(even.data(),half,type);
    fft(odd.data(),half,type);
    // Butterfly: combine the two half-size transforms with twiddles w = Wn^i.
    const complex Wn(cos(2.0*pi/n),type*sin(2.0*pi/n));
    complex w(1,0);
    for(int i=0;i<half;i++,w=w*Wn){
        const complex t=w*odd[i];
        a[i]=even[i]+t;
        a[i+half]=even[i]-t;
    }
}
// Coefficient arrays of the two input polynomials; main reuses a for the product.
complex a[N];
complex b[N];
int main(){
int n,m,x;
scanf("%d",&n);
scanf("%d",&m);
for(int i=0;i<=n;i++) scanf("%d",&x),a[i].x = x,a[i].y=0;
for(int i=0;i<=m;i++) scanf("%d",&x),b[i].x = x,b[i].y=0;
int len = 1;
while(len <= (n+m)) len *= 2;
fft(a,len,1);
fft(b,len,1);
for(int i=0;i<=len;i++){
a[i] = a[i] * b[i];
}
fft(a,len,-1);
for(int i=0;i<=(n+m);i++) printf("%0.lf ",fabs(a[i].x/(len)));
return 0;
}
迭代实现
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <iostream>
// Capacity of the global arrays (4e6 + 10 elements).
const int N = 4000010;
// pi, obtained as arccos(-1) at start-up.
const double pi = acos(-1.0);
using namespace std;
// long long shorthand (not used by the code shown here).
using ll = long long;
// Minimal complex number used by the FFT: x is the real part, y the
// imaginary part. The arithmetic operators are const member functions so they
// can also be applied to const operands (e.g. a const twiddle factor).
struct complex{
    double x,y;
    complex(double xx=0,double yy=0):x(xx),y(yy){}
    complex operator +(const complex &a) const {return complex(x+a.x,y+a.y);}
    complex operator -(const complex &a) const {return complex(x-a.x,y-a.y);}
    // (x + iy)(a.x + i a.y) = (x a.x - y a.y) + i (x a.y + y a.x)
    complex operator *(const complex &a) const {return complex(x*a.x-y*a.y,x*a.y+y*a.x);}
};
// Bit-reversal permutation table read by fft(); main fills it before use.
int r[N];
// Coefficient / point-value arrays of the two input polynomials.
complex a[N];
complex b[N];
// Iterative radix-2 FFT, in place on a[0..n).
// Precondition: the global table r[] holds the bit-reversal permutation for
// this n (main builds it right before the first call).
// type = 1: forward transform; type = -1: inverse WITHOUT the 1/n scaling.
void fft(complex* a,int n,int type){
    // Move every element to its bit-reversed slot; each pair is swapped once.
    for(int i=0;i<n;++i){
        const int j=r[i];
        if(j>i) swap(a[i],a[j]);
    }
    // half = length of the already-transformed sub-blocks being merged.
    for(int half=1;half<n;half<<=1){
        const complex step(cos(pi/half),type*sin(pi/half));
        const int width=half<<1;
        for(int start=0;start<n;start+=width){
            complex w(1,0);
            for(int k=0;k<half;++k){
                complex& lo=a[start+k];
                complex& hi=a[start+k+half];
                const complex t=w*hi;
                hi=lo-t;   // uses lo before it is overwritten below
                lo=lo+t;
                w=w*step;
            }
        }
    }
}
int main(){
int n,m,x;
scanf("%d",&n);
scanf("%d",&m);
for(int i=0;i<=n;i++) scanf("%d",&x),a[i].x = x,a[i].y=0;
for(int i=0;i<=m;i++) scanf("%d",&x),b[i].x = x,b[i].y=0;
int len = 1,l=0;
while(len <= (n+m)) len *= 2,l++;
for(int i=0;i<len;i++)
r[i] =(r[i>>1]>>1)|((i&1)<<(l-1));
fft(a,len,1);
fft(b,len,1);
for(int i=0;i<=len;i++){
a[i] = a[i] * b[i];
}
fft(a,len,-1);
for(int i=0;i<=(n+m);i++) printf("%d ",int(a[i].x/(len)+0.5));
return 0;
}