[学习笔记] 拆系数 FFT

目录

0. 楔子

我们知道,\(\mathtt{FFT}\) 是用复数运算的,当我们需要取模且数据范围较大的时候就没有办法了。

比如 这道题

如果能减小数据范围,我们就可以先算再模,于是,拆系数 \(\mathtt{FFT}\) 就闪亮登场了!

1. 正文

注意:\(*\) 表示卷积。

首先计算一下本题的数据范围:\(10^9\times 10^9\times 10^5=10^{23}\)(两数相乘,加 \(n\) 遍)。

设 \(base\) 为一个 \(\sqrt p\) 级别的数,可以将 \(F(x),G(x)\) 分别分解成 \(A(x)\times base+B(x),C(x)\times base+D(x)\)(注意要除尽),这样拆出来的函数是 \(<\sqrt p\) 的。

\[F*G=(A\times base+B)*(C\times base+D) \]

\[=A*C\times base^2+(A*D+B*C)\times base+B*D \]

这样数据范围就是 \(10^{14}\) 左右。但是这样要做 \(7\) 次 \(\mathtt{FFT}\)(插值 \(4\) 次,转换 \(3\) 次)。

我们知道,\(\mathtt{FFT}\) 的虚部赋值为 \(0\),我们可以利用这个空间。

考虑计算 \(A,B,C,D\) 的点值表示。令

\[f(k)=A(k)+i\times B(k) \]

\[g(k)=A(k)-i\times B(k) \]

\[h(k)=C(k)+i\times D(k) \]

我们发现 \(f(k),g(n-k)\) 是共轭的(\(n\) 是 \(n\) 补全的 \(2\) 的幂)。

\[f(k)=A(k)+i\times B(k) \]

\[=\sum_{j=0}^{n-1}a_j(\omega_n^k)^j+i\times \sum_{j=0}^{n-1}b_j(\omega_n^k)^j \]

\[=\sum_{j=0}^{n-1}(a_j+i\times b_j)(\omega_n^k)^j \]

\[=\sum_{j=0}^{n-1}(a_j+i\times b_j)(\cos(\frac{2\pi kj}{n})+i\sin (\frac{2\pi kj}{n})) \]

\[=\sum_{j=0}^{n-1}(a_j\times \cos(\frac{2\pi kj}{n})-b_j\times \sin (\frac{2\pi kj}{n}))+i(b_j\times \cos(\frac{2\pi kj}{n})+a_j\times \sin (\frac{2\pi kj}{n})) \]

\[g(n-k)=A(n-k)-i\times B(n-k) \]

\[=\sum_{j=0}^{n-1}a_j(\omega_n^{-k})^j-i\times \sum_{j=0}^{n-1}b_j(\omega_n^{-k})^j \]

\[=\sum_{j=0}^{n-1}(a_j-i\times b_j)(\omega_n^{-k})^j \]

\[=\sum_{j=0}^{n-1}(a_j-i\times b_j)(\cos(\frac{2\pi kj}{n})-i\sin (\frac{2\pi kj}{n})) \]

\[=\sum_{j=0}^{n-1}(a_j\times \cos(\frac{2\pi kj}{n})-b_j\times \sin (\frac{2\pi kj}{n}))-i(b_j\times \cos(\frac{2\pi kj}{n})+a_j\times \sin (\frac{2\pi kj}{n})) \]

所以计算 \(f,g,h\) 的点值表达式只用 \(2\) 次 \(\mathtt{FFT}\)。

令 \(p=f*h,q=g*h\)。所以

\[p=A*C-B*D+i(A*D+B*C) \]

\[q=A*C+B*D+i(A*D-B*C) \]

左边的每一项就是右边卷完后每一项系数经过一番运算。

所以计算出 \(p,q\) 需要 \(2\) 次 \(\mathtt{FFT}\),将 \(p,q\) 对应项相加即可解出 \(A*C,A*D\),从而都解出来。

总共需要 \(4\) 次 \(\mathtt{FFT}\)。

2. 代码

用到了预处理单位根,这样精度会高一些。

#include <cstdio>

#define rep(i,_l,_r) for(register signed i=(_l),_end=(_r);i<=_end;++i)
#define fep(i,_l,_r) for(register signed i=(_l),_end=(_r);i>=_end;--i)
#define erep(i,u) for(signed i=head[u],v=to[i];i;i=nxt[i],v=to[i])
#define efep(i,u) for(signed i=Head[u],v=to[i];i;i=nxt[i],v=to[i])
#define print(x,y) write(x),putchar(y)

template <class T> inline T read(const T sample) {
    T x=0; int f=1; char s;
    while((s=getchar())>'9'||s<'0') if(s=='-') f=-1;
    while(s>='0'&&s<='9') x=(x<<1)+(x<<3)+(s^48),s=getchar();
    return x*f;
}
template <class T> inline void write(const T x) {
    if(x<0) return (void) (putchar('-'),write(-x));
    if(x>9) write(x/10);
    putchar(x%10^48);
}
template <class T> inline T Max(const T x,const T y) {if(x>y) return x; return y;}
template <class T> inline T Min(const T x,const T y) {if(x<y) return x; return y;}
template <class T> inline T fab(const T x) {return x>0?x:-x;}
template <class T> inline T gcd(const T x,const T y) {return y?gcd(y,x%y):x;}
template <class T> inline T lcm(const T x,const T y) {return x/gcd(x,y)*y;}
template <class T> inline T Swap(T &x,T &y) {x^=y^=x^=y;}

#include <cmath>
#include <iostream>
using namespace std;
typedef long long ll;

const double Pi=acos(-1.0);
const int num1=(1<<30),num2=(1<<15),maxn=262150;

int n,m,mod,rev[maxn],lim,bit;
ll a1b1,a1b2,a2b1,a2b2;
struct cp {
	double x,y;
	
	cp operator + (const cp t) const {return (cp){x+t.x,y+t.y};}
	cp operator - (const cp t) const {return (cp){x-t.x,y-t.y};}
	cp operator * (const cp t) const {return (cp){x*t.x-y*t.y,y*t.x+x*t.y};}
} f[maxn],g[maxn],h[maxn],w[maxn][2],tmp;

void FFT(cp *f,const int op=1) {
	rep(i,0,lim-1) if(i<rev[i]) swap(f[i],f[rev[i]]);
	for(int mid=1;mid<lim;mid<<=1) {
		for(int i=0,p=(mid<<1);i<lim;i+=p) {
			for(int j=0;j<mid;++j) {
				tmp=w[lim/mid/2*j][op==1]*f[i+j+mid];
				f[i+j+mid]=f[i+j]-tmp,f[i+j]=f[i+j]+tmp;
			}
		}
	}
}

void init() {
	lim=1;
	while(lim<=n+m) lim<<=1,++bit;
	rep(i,0,lim-1) {
		rev[i]=(rev[i>>1]>>1)|((i&1)<<bit-1);
		w[i][1]=(cp){cos(Pi*2*i/lim),sin(Pi*2*i/lim)};
		w[i][0]=(cp){w[i][1].x,-w[i][1].y};
	}
}

void MTT() {
	int val;
	rep(i,0,n) val=read(9),f[i]=(cp){val>>15,val&32767};
	rep(i,0,m) val=read(9),h[i]=(cp){val>>15,val&32767};
	init();
	FFT(f),FFT(h);
	g[0]=(cp){f[0].x,-f[0].y};
	rep(i,1,lim-1) g[i]=(cp){f[lim-i].x,-f[lim-i].y};
	// 先除以 lim,这样就只用除一次
	rep(i,0,lim-1) h[i].x/=lim,h[i].y/=lim,f[i]=f[i]*h[i],g[i]=g[i]*h[i];
	FFT(f,-1),FFT(g,-1);
	rep(i,0,n+m) {
		a1b1=(ll)((f[i].x+g[i].x)/2+0.5)%mod;
		a1b2=(ll)((f[i].y+g[i].y)/2+0.5)%mod;
		a2b1=((ll)(f[i].y+0.5)-a1b2)%mod;
		a2b2=((ll)(g[i].x+0.5)-a1b1)%mod;
		print((a1b1*num1%mod+(a1b2+a2b1)*num2%mod+a2b2)%mod,' ');
	}
	puts("");
}

int main() {
	n=read(9),m=read(9),mod=read(9);
	MTT();
	return 0;
} 
上一篇:【洛谷P4245】【模板】任意模数多项式乘法


下一篇:P4721 【模板】分治 FFT