FFT 学习笔记
前置知识
-
代数基本定理:\(n\) 次函数最多与 \(x\) 轴有 \(n\) 个交点
-
系数表示法:字面意思,用 \(x^0\sim x^n\) 处的一连串系数表示这个多项式
-
点值表示法:把多项式视为一个 \(n\) 次函数,取函数图像上的 \(n+1\) 个点 \((x_0,y_0),(x_1,y_1),…,(x_n,y_n)\)表示这个函数(前提是忽略倍数点之类的无用点)
-
\(n\) 次单位根:学过复数最好,没学过只需要记住它是一类特殊的数(实质上是把复平面上的单位圆均分成 \(n\) 份后得到的 \(n\) 个点),记做 \(\omega_n\),满足以下几条性质:\(\omega_n^n=\omega_n^1=1(\omega_{2n}^n=-1),\omega_{2n}^{2k}=\omega_n^k,w_{2n}^{k+n}=-w_{2n}^k,w_{n}^{k}=w_{n}^{k\%n}\)
Q:为什么取 \(n+1\) 个点就能唯一表示一个 \(n\) 次函数?
A:有两种角度。
第一种是高斯消元,\(n\) 个方程式可以解出 \(n\) 个未知数,所以取 \(n+1\) 个点相当于给定了 \(n+1\) 个关于 \(a_0,a_1,…,a_n\) 的方程,可以唯一地解出这 \(n+1\) 个系数。
第二种是反证法。假设有两个次数小于等于 \(n\) 的多项式 \(f(x),g(x)\) 满足其对应的函数有 \(n+1\) 个不同的交点,令 \(h(x)=f(x)-g(x)\),那么 \(h(x)\) 的次数一定小于等于 \(n\),并且其与 \(x\) 轴有 \(n+1\) 个交点。由代数学基本定理,这是不可能的,矛盾!所以在给定 \(n+1\) 个点后只有一个次数小于等于 \(n\) 的多项式满足条件。
FFT 原理初探
以下涉及到的运算如无特殊说明均为复数域上的运算。
记 \(h(x)=f(x)\cdot g(x)\)。如果 \(f,g\) 用的是同一 \(x\) 序列得到的点值表示法,那么直接将对应点值相乘即可得到 \(h\) 的点值表示,这样做的复杂度是 \(O(n)\) 的。
这启发我们先把 \(f,g\) 由系数表达转化为点值表达,得到 \(h\) 的点值表达后再转为系数表达。系数表达转点值表达,随便选一组 \(x\) 代入是 \(n^2\) 的,不能接受。
还记得前置知识里的 \(n\) 次单位根吗?它拥有很良好的性质,我们不妨尝试把所有的 \(n\) 次单位根即 \(w_n^0,w_n^1,…,w_n^{n-1}\) 代入 \(f\) 中求出点值表达,这个过程叫作 DFT。在此我们假定 \(n\) 是 \(2\) 的整次幂,如果 \(n\) 不足就补到最近的幂,还有个小细节是如果 \(n\) 本身就是 2 的整次幂,仍需把 \(n\) 补到 \(2n\),因为 \(n\) 次函数需要 \(n+1\) 个点才能唯一确定。
DFT 中,点值表达的第 \(k\) 项 \(y_k=\sum_{i=0}^{n-1}a_i(\omega_n^k)^i\),按照 \(i\) 的奇偶性拆开,\(y_k=\sum_{i=0}^{\frac{n}{2}-1}a_{2i}(\omega_n^k)^{2i}+\sum_{i=0}^{\frac{n}{2}-1}a_{2i+1}(\omega_n^k)^{2i+1}\) 。注意到 \((\omega_n^k)^{2i}=(\omega_n^{2k})^i=(\omega_\frac{n}{2}^k)^i,(\omega_n^k)^{2i+1} = \omega_n^k\cdot(\omega_n^{2k})^i = \omega_n^k(\omega_\frac{n}{2}^k)^i\),则 \(y_k=\sum_{i=0}^{\frac{n}{2}-1}a_{2i}(\omega_{\frac{n}{2}}^{k})^i+w_{n}^{k}\sum_{i=0}^{\frac{n}{2}-1}a_{2i+1}(\omega_{\frac{n}{2}}^{k})^{i}\)。如果把 \(a_{2i}\) 看做一个次数为 \(n\) 的多项式 \(F\) 的系数表达,\(a_{2i+1}\) 看做另一个次数为 \(n\) 的多项式 \(G\) 的系数表达,前一个和式其实就是 DFT(F) 的第 \(k\% \frac{n}{2}\) 项,后一个和式则为 DFT(G) 的第 \(k\%\frac{n}{2}\) 项。所以我们只需要递归求出 DFT(F) 及 DFT(G) 就可以 \(O(n)\) 求出 DFT(F)。这实际上是一个分治的结构,在递归到 \(n=1\) 时,DFT(f) 就等于 \(f\) 系数表达中唯一的一项 \(a_0\)(这是由于 \(w_n^0=1\)),直接返回,然后自底向上合并,总时间复杂度显然为 \(O(n\log n)\)。
现在我们能够在 \(O(n\log n)\) 的时间内求出 DFT(h),而由 DFT(h) 反推出 \(h\) 的过程叫作 idft。这里有个高明的结论是 h 的系数表达等于把 dft(h) 的结果当做另一个多项式的系数表达再做一遍 dft 然后把得到的序列除以 \(n\),只不过这次 dft 代入的是 \(w_n^{-i}\) 而非 \(\omega_n^i\),即
\[n\times h_i=\sum_{j=0}^{n-1}(\omega_n^{-i})^j\cdot dft_j(h)=\sum_{j=0}^{n-1}(\omega_n^{-i})^j \sum_{k=0}^{n-1}h_k(\omega_n^j)^k=\sum_{0\le j,k<n}h_k\omega_n^{j(k-i)} \]对于每个 \(k\),如果 \(k\not=i\),那么 \(\omega_n^{j(k-i)}\) 构成了一个公比为 \(w_n^{k-i}\) 的等比数列,其首项为 \(\omega_n^0=1\),末项为 \((\omega_n^{k-i})^{n-1}\),由等比数列求和公式知 \(\sum_{0\le j<n}(\omega_n^{k-i})^j=\frac{(\omega_n^{k-i})^n-1}{\omega_n^{k-i}-1}=0\)。而如果 \(k=i\),其贡献为 \(nh_k=nh_i\),恰等于左式,证毕。所以我们可以把 dft 和 idft 放在一个函数里写:
#define db double
const int N=3e6+5;
const db Pi=acos(-1);
struct cp{
db x,y;//重载复数类
cp(){}
cp(db a,db b){x=a,y=b;}
}w[N],tmp[N];
cp operator+(const cp &x,const cp &y){return cp(x.x+y.x,x.y+y.y);}
cp operator-(const cp &x,const cp &y){return cp(x.x-y.x,x.y-y.y);}
cp operator*(const cp &a,const cp &b){return cp(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
cp f[N],g[N],h[N],a[N];
int n,m;
void dft(cp *f,int n,int flg){
if(n==1) return;
fo(i,0,n-1) tmp[i]=f[i];
cp *L=f,*R=f+n/2;
fo(i,0,n/2-1) L[i]=tmp[i<<1];
fo(i,0,n/2-1) R[i]=tmp[i<<1|1];
dft(L,n/2,flg);
dft(R,n/2,flg);
cp sum=cp(1,0),I=cp(cos(2*Pi/n),flg*sin(2*Pi/n));
fo(i,0,n-1){
w[i]=sum;
sum=sum*I;//这里是精度误差的重灾区
}
fo(i,0,n/2-1){
cp x=L[i],y=R[i]*w[i];
L[i]=x+y,R[i]=x-y;
}
}
signed main(){
cin>>n>>m;int lim=1;
while(lim<=n+m) lim<<=1;
fo(i,0,n) f[i]=cp(read(),0);
fo(i,0,m) g[i]=cp(read(),0);
dft(f,lim,1);
dft(g,lim,1);
fo(i,0,lim-1) h[i]=f[i]*g[i];
dft(h,lim,-1);
fo(i,0,n+m) cout<<(int)(h[i].x/lim+0.49)<<' ';
//FFT 精度丢失很严重,你必须四舍五入才能保证正确性……
return 0;
}
常数优化
蝴蝶变换
上述代码在实际题目中的表现极差,因为它采用递归实现,并且进行了大量的内存拷贝。细想一下,我们不断向下递归其实只是为了把系数序列按照奇偶性划分成两边。而在回溯时我们只是把这两边合并起来,并没有进行内存的拷贝等操作。如果能预先处理出递归到最底层时序列被划分成的模样,是不是就可以避免递归和内存拷贝了?
直接抛结论:最终得到的序列的第 \(i\) 项等于原序列的第 \(rev_i\) 项,其中 \(rev_i\) 表示把 \(i\) 的二进制翻转得到的数,如 \(rev_6=3(rev_{110}=011)\)。为什么?考虑最普通的序列分治(如归并排序)其实是按二进制下的最高位分成两边,最终它得到的还是原序列,因为原序列本来就是按照二进制最高位的顺序排列的。而在 dft 中,按奇偶性分类实则是按二进制下的最低位分成两边,等价于按照最低位的大小排序,最终得到的序列就是按 \(rev_i\) 从小到大排序的结果!因此我们可以提前预处理出 \(rev\) 数组,然后把序列按 \(rev\) 数组变换为它在原本递归到最底层时的样子,自底向上合并,这个过程被称为蝴蝶变换(蝶形优化)。你发现这样做连递归都省了,代码也更为短小精悍:
void fft(cp *f,int n,int flg){
fo(i,0,n-1) if(i<rev[i]) swap(f[i],f[rev[i]]);
w[0]=cp(1,0);
for(int j=2;j<=n;j<<=1){
cp I=cp(cos(2*Pi/j),flg*sin(2*Pi/j));
go(i,j-1,0){
if(i&1) w[i]=w[i>>1]*I;
else w[i]=w[i>>1];
}
for(int i=0;i<n;i+=j){
fo(k,i,i+j/2-1){
cp qwq=w[k-i]*f[k+j/2];
f[k+j/2]=f[k]-qwq;
f[k]=f[k]+qwq;
}
}
}
}
signed main(){
cin>>n>>m;lim=1;
fo(i,0,n) f[i]=cp(read(),0);
fo(i,0,m) g[i]=cp(read(),0);
while(lim<=n+m) lim<<=1;
fo(i,0,lim-1) rev[i]=(rev[i>>1]>>1)|(i&1?(lim>>1):0);//你品,你细品
fft(f,lim,1);
fft(g,lim,1);
fo(i,0,lim-1) h[i]=f[i]*g[i];
fft(h,lim,-1);
fo(i,0,n+m) cout<<(int)(h[i].x/lim+0.49)<<' ';
return 0;
}
还有个小 trick 是把单位根也一块预处理出来放在同一个数组里,不再赘述
三次变两次
一个不容忽视的事实是 \(f,g,h\) 这三个多项式的系数都只有实部没有虚部,在 fft 时却不得不当做完整的复数来运算,这难免会造成一些计算上的冗余。
考虑把 \(g\) 塞到 \(f\) 的虚部上去,令 \(P(x)=f(x)+g(x)i\),则 \(P^2(x)=f^2(x)-g^2(x)+2f(x)g(x)i\),因此 \(P^2(x)\) 的虚部除以 \(2\) 就是 \(f(x)g(x)\)。所以我们先算出 \(dft(P)\),然后用它乘自己得到 \(dft(P^2)\),再 idft 回去即可。这样只需要做两次 \(dft\)。
cin>>n>>m;lim=1;
fo(i,0,n) f[i].x=read();
fo(i,0,m) f[i].y=read();
while(lim<=n+m) lim<<=1;
fo(i,0,lim-1) rev[i]=(rev[i>>1]>>1)|(i&1?(lim>>1):0);
//out(rev,0,lim-1);
fft(f,lim,1);
fo(i,0,lim-1) h[i]=f[i]*f[i];
fft(h,lim,-1);
fo(i,0,n+m) cout<<(int)(h[i].y/2/lim+0.49)<<' ';