概述
FFT,即 快速傅里叶变换 ,是将多项式乘法从 \(O(n^2)\) 优化到 \(O(n\log n)\) 的算法。
本质上是优化卷积,卷积的一般形式:
\[C(i)=\sum\limits_{i\oplus j=k}A(i)B(i) \]其中多项式乘法为加法卷积,即:
\[C(i)=\sum\limits_{i+j=k}A(i)B(i) \]系数表示法:
我们可以用每个项的系数来表示这个多项式:
\[f(x)=a_nx^n+\dots +a_2x^2+a_1x+a_0\Rightarrow f(x) = \{a_0,a_1,\dots,a_n\} \]点值表示法:
我们可以把多项式看成一个函数,那么从平面上取 \(n+1\) 个点,则可以确定一个 \(n+1\) 项的函数。
原因可以理解成 \(n+1\) 个多项式,也就是 \(n+1\) 个方程,可以解出 \(n+1\) 元的方程。
FFT 的原理就是将对于每一个多项式 \(A(x)\) 和 \(B(x)\) 由系数表示法变成点值表示法,\(A(x)\) 和 \(B(x)\) 相乘得 \(C(x)\) ,再变成系数表示法。将系数表示法变成点值表示法的过程称为 DFT,点值表示法变成系数表示法的过程称为 IDFT。
但如果暴力算的话,\(A(x)\) 和 \(B(x)\) 变成点值表示法是 \(O(n^2)\) 的,再用高斯消元求解是 \(O(n^3)\) ,所以 FFT 可以加速这个过程。
单位根
对于点值表示法的 \(n\) 个 \(A(x)\) 和 \(B(x)\) ,需要快速算出它们的值。
直接算 \(O(n^2)\) 的,发现对于 \(x^n=1\) 的 \(x\) ,可以 \(O(n)\) 计算出 \(A(x)\) 和 \(B(x)\) 。
在实数域中,只有 \(1\) 和 \(-1\) 满足 \(O(n)\) 算出多项式。
在复数域中,则有 \(i\) 和 \(-i\) 满足要求。
考虑复数乘法,表示为 模长相乘,辐角相加 ,则两个模长为 \(1\) 的向量相乘,得到的还是模长为 \(1\) 的向量。
可以定义 \(x^n=1\) 在复数意义下的解为 \(n\) 次复根,即 \(\omega_{n}\) 。这样的复根有 \(n\) 个,表示为 \(\omega_{n}^k\) ,其中 \(k=0,1,\dots,n-1\)
OI-WIKI上的图:
性质:
\[\omega_{n}^k=\omega_{2n}^{2k} \\ \omega_{2n}^{n+k}=-\omega_{2n}^k \]快速傅里叶变换:
FFT 的基本思想是 分治 。
DFT
我们将多项式分为奇次项和偶次项处理。
对于一个 \(8\) 项多项式,按照次数分为两组:
\[f(x)=(a_0+a_2x^2+a_4x^4+a_6x^6)+(a_1x+a_3x^3+a_5x^5+a_7x^7)\\=(a_0+a_2x^2+a_4x^4+a_6x^6)+x(a_1+a_3x^2+a_5x^4+a_7x^6) \] \[G(x)=a_0+a_2x+a_4x^2+a_6x^3\\ H(x)=a_1+a_3x+a_5x^2+a_7x^3 \]则有:
\[f(x)=G(x^2)+x\times H(x) \] \[\begin{aligned}\operatorname{DFT}(f(\omega_n^k))&=\operatorname{DFT}(G((\omega_n^k)^2))+\omega_n^k\times \operatorname{DFT}(H((\omega_n^k)^2))\\ &=\operatorname{DFT}(G(\omega_n^{2k})) + \omega_n^k\times \operatorname{DFT}(H(\omega_n^{2k}))\\ &=\operatorname{DFT}(G(\omega_{n/2}^k)) + \omega_n^k \times \operatorname{DFT}(H(\omega_{n/2}^k))\end{aligned} \]同理可得:
\[\begin{aligned}\operatorname{DFT}(f(\omega_n^{k+n/2}))&=\operatorname{DFT}(G(\omega_n^{2k+n}))+\omega_n^{k+n/2}\times \operatorname{DFT}(H(\omega_n^{2k+n}))\\&=\operatorname{DFT}(G(\omega_n^{2k}))-\omega_n^k\times \operatorname{DFT}(H(\omega_n^{2k}))\\&=\operatorname{DFT}(G(\omega_{n/2}^k))-\omega_n^k\times \operatorname{DFT}(H(\omega_{n/2}^k))\end{aligned} \]所以我们可以对 \(G\) 和 \(H\) 分别递归求解,得出 \(\operatorname{DFT}(f(\omega_{n}^{k}))\) 和 \(\operatorname{DFT}(f(\omega_{n}^{k+n/2}))\)
位逆序置换
DFT 很明显可以用递归求解,但它还可以继续优化。
考虑到递归的过程是不断对若干个长为 \(2^m\) 的子段拆分成两个长为 \(2^{m-1}\) 的子段,并且每次递归 \(m-1\) 直到 \(m=1\) ,那么可以考虑从 \(m=1\) 的状态开始合并。
容易发现每个数在拆分到 \(m=1\) 时的位置和原始位置的二进制是翻转的,那么求出 rev[i]
表示 \(i\) 的二进制反转后的值即可。
这里给出代码:
Code-rev
for (int i = 0; i < len; ++i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (num - 1));
IDFT
IDFT 相当于已知 \(y_i=f(\omega_{n}^i)\) ,\(i\in\{0,1,\dots,n-\}\) ,求 \(\{a_0,a_1,\dots,a_{n-1}\}\)
我们取单位根的倒数,跑一边 FFT,再将求得的 \(y\) 除以 \(n\) ,得到原来的 \(a\) 。
证明略。
Code-FFT
void FFT(Fu *y, int on) {
for (int i = 0; i < len; ++i)
if (i < rev[i]) swap(y[i], y[rev[i]]);
for (int h = 2; h <= len; h <<= 1) {
Fu wn = (Fu) {cos(2 * PI / h), sin(on * 2 * PI / h)};
for (int j = 0; j < len; j += h) {
Fu w = (Fu) {1, 0};
for (int k = j; k < j + h / 2; ++k) {
Fu u = y[k], t = w * y[k + h / 2];
y[k] = u + t, y[k + h / 2] = u - t;
w = w * wn;
}
}
}
if (on == -1)
for (int i = 0; i < len; ++i)
y[i].x /= len;
}
例题
P3803
模板题,给出代码。
Code
#include <bits/stdc++.h>
#define LL long long
#define ULL unsigned long long
#define DB double
#define PR pair <int, int>
#define MK make_pair
#define pb push_back
#define fi first
#define se second
#define RI register int
#define Low(x) (x & (-x))
using namespace std;
const int kN = 4e6 + 5;
const DB PI = acos(-1.0);
int n, m, len = 1, num = 0, rev[kN];
struct Fu {
DB x, y;
Fu operator + (const Fu &K) const {return (Fu) {x + K.x, y + K.y};}
Fu operator - (const Fu &K) const {return (Fu) {x - K.x, y - K.y};}
Fu operator * (const Fu &K) const {return (Fu) {x * K.x - y * K.y, x * K.y + y * K.x};}
} f[kN], g[kN];
void FFT(Fu *y, int on) {
for (int i = 0; i < len; ++i)
if (i < rev[i]) swap(y[i], y[rev[i]]);
for (int h = 2; h <= len; h <<= 1) {
Fu wn = (Fu) {cos(2 * PI / h), sin(on * 2 * PI / h)};
for (int j = 0; j < len; j += h) {
Fu w = (Fu) {1, 0};
for (int k = j; k < j + h / 2; ++k) {
Fu u = y[k], t = w * y[k + h / 2];
y[k] = u + t, y[k + h / 2] = u - t;
w = w * wn;
}
}
}
if (on == -1)
for (int i = 0; i < len; ++i)
y[i].x /= len;
}
signed main() {
scanf("%d%d", &n, &m);
for (int i = 0; i <= n; ++i) scanf("%lf", &f[i].x);
for (int i = 0; i <= m; ++i) scanf("%lf", &g[i].x);
while (len <= n + m) len <<= 1, num++;
for (int i = 0; i < len; ++i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (num - 1));
FFT(f, 1), FFT(g, 1);
for (int i = 0; i < len; ++i) f[i] = f[i] * g[i];
FFT(f, -1);
for (int i = 0; i <= n + m; ++i) printf("%d ", (int) (f[i].x + 0.5));
return 0;
}