Preface
最近几天学了一下FFT和NTT,感觉这东西理解了之后也没有那么难(其实我还不会证明IDFT)
我本来是准备写一篇特别详细的总结,结果发现了一篇和我想写的内容相近的博客 传送门
以及一篇只需初中数学知识的零基础学习笔记 传送门
所以我就只讲一下大致的算法过程,具体可以去看一下链接的两篇博客
先了解复数相关的性质和运算以及单位圆后,食用效果更佳
Problem
给你两个多项式,求它们的卷积
Process
1.DFT
将给定的两个多项式从系数式转换为点值式
点值式是指用平面上\(n + 1\)个横坐标互不相同的点来唯一确定一个\(n\)次多项式,即给出\(F(x)\)在\(n + 1\)个不同的\(x\)处的取值
直接暴力代入\(n + 1\)个点分别求值的复杂度是\(\mathcal{O}(n^2)\),显然不行,所以我们要想办法优化
我们可以考虑把当前的问题转化为子问题,然后再从子问题快速求解当前的问题
假设我们现在要求\(F(x) = \sum\limits_{i = 0}^{n - 1} a_i \cdot x^i\)的点值式,保证\(n = 2^k (k \in N)\)
我们可以把式子变一下形
\[
F(x) = (a_0 + a_2 \cdot x^2 + \cdots + a_{n - 2} \cdot x^{n - 2}) + (a_1 \cdot x + a_3 \cdot x^3 + \cdots + a_{n - 1} \cdot x^{n - 1})
\]
我们令
\[
G(x) = a_0 + a_2 \cdot x + \cdots + a_{n - 2} \cdot x^{\frac{n}{2} - 1} \\
G'(x) = a_1 + a_3 \cdot x + \cdots + a_{n - 1} \cdot x^{\frac{n}{2} - 1}
\]
则
\[
F(x) = G(x^2) + x \cdot G'(x^2)
\]
但这样好像还是没有转换为一模一样的子问题,所以我们可以考虑带一些具有(qi)某些(qi)特殊(guai)性质(guai)的数值进去
在经过前人无数次尝试之后,发现可以代入\(\omega\)到式子里去,这是因为\(\omega\)有一些比较神奇的性质
\(\omega\)的本质是一个复数,且满足\(\omega^n = 1\),所以显然\(\omega\)只能在单位圆上
于是我们记\(\omega_n^k (k \in [0, n))\)为单位根,如果我们把这些单位根看成矢量,那么它们便会\(n\)等分这个单位圆
它有这样一些性质:(字母均为整数)
- \(\omega_n^k = \omega_n^{k + a \cdot n}\)
- \(\omega_n^{k_1} \cdot \omega_n^{k_2}= \omega_n^{k_1 + k_2}\)
- \(\omega_{d \cdot n}^{d \cdot k} = \omega_n^k\)
- \(\omega_n^{k + \frac{n}{2}} = - \omega_n^{k}\)
所以,(保证\(k < \frac{n}{2}\))
\[
F(\omega_n^k) = G((\omega_n^k)^2) + \omega_n^k \cdot G'((\omega_n^k)^2) \\
= G(\omega_n^{2k}) + \omega_n^k \cdot G'(\omega_n^{2k}) \\
= G(\omega_{n / 2}^k) + \omega_n^k \cdot G'(\omega_{n / 2}^k)
\]
因为\(\omega_n^{k + n / 2} = - \omega_n^k\),所以
\[
F(\omega_n^{k + n / 2}) = G(\omega_{n / 2}^k) - \omega_n^k \cdot G'(\omega_{n / 2}^k)
\]
综上,\(F\)函数的点值均可从函数\(G\)和\(G'\)转移过来,复杂度\(\mathcal{O}(n)\)
分治后总时间复杂度\(\mathcal{O}(n \log n)\) (此\(n\)非彼\(n\))
2. 点值式相乘
我们把两个点值式在相同\(x\)处的取值对应相乘,即得到它们卷积的点值式
3. IDFT
将答案的点值式转换为系数式,经过一番矩阵的巧妙证明之后 假装自己会
只需要把DFT中用到的\(\omega_n^k\)全部换成\(\omega_n^{-k}\)再进行一次DFT,最后把每个结果除以\(n\),即为IDFT
至此,我们便得到了答案多项式的系数
Code
FFT
#include <bits/stdc++.h>
using namespace std;
#define fst first
#define snd second
#define mp make_pair
#define squ(x) ((LL)(x) * (x))
#define debug(...) fprintf(stderr, __VA_ARGS__)
typedef long long LL;
typedef pair<int, int> pii;
// Raises a to b when b is strictly larger; returns whether a was changed.
template<typename T> inline bool chkmax(T &a, const T &b) {
    if (!(a < b)) return false;
    a = b;
    return true;
}
// Lowers a to b when b is strictly smaller; returns whether a was changed.
template<typename T> inline bool chkmin(T &a, const T &b) {
    if (a > b) {
        a = b;
        return true;
    }
    return false;
}
// Reads the next decimal integer from stdin. Characters before the first digit
// are skipped, and any '-' among them makes the result negative.
// Returns 0 if the input ends before any digit is found.
inline int read() {
    int sum = 0, fg = 1;
    // int (not char) so EOF stays distinguishable from real bytes and isdigit()
    // only ever receives EOF or an unsigned-char value, as it requires.
    int c = getchar();
    for (; c != EOF && !isdigit(c); c = getchar()) if (c == '-') fg = -1;
    for (; c != EOF && isdigit(c); c = getchar()) sum = sum * 10 + (c - '0');
    return fg * sum;
}
namespace FFT {
// Capacity of every internal buffer. The transform length len is a power of two
// no larger than this, so callers of mul() must keep n + m < MAX_LEN.
const int MAX_LEN = 1 << 21;
const double PI = std::acos(-1.0);

// Minimal complex number: a is the real part, b the imaginary part.
struct com {
    double a, b;
    com (double _a = 0.0, double _b = 0.0): a(_a), b(_b) { }
    com operator + (const com &t) const { return com(a + t.a, b + t.b); }
    com operator - (const com &t) const { return com(a - t.a, b - t.b); }
    com operator * (const com &t) const { return com(a * t.a - b * t.b, a * t.b + b * t.a); }
};

// len: current transform length (a power of two); cnt: log2(len) - 1;
// rev: bit-reversal permutation of [0, len); g[i] = e^(2*pi*I*i/len).
int len, cnt, rev[MAX_LEN];
com g[MAX_LEN];

// Prepares a transform able to hold a polynomial of degree N:
// len becomes the smallest power of two strictly greater than N.
void init(int N) {
    for (cnt = -1, len = 1; len <= N; len <<= 1) ++cnt;
    // The lowest bit of i becomes the highest of log2(len) bits, i.e. len / 2.
    // Expressed without "<< cnt" because cnt == -1 when len == 1, and a
    // negative shift count is undefined behaviour.
    for (int i = 0; i < len; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? len >> 1 : 0);
    // Each twiddle factor is evaluated directly; building them by repeated
    // multiplication would accumulate rounding error that grows with i.
    for (int i = 0; i < len; i++) {
        const double angle = 2 * PI * i / len;
        g[i] = com(std::cos(angle), std::sin(angle));
    }
}

// In-place iterative radix-2 transform of x[0 .. len) using the tables from init().
// op == -1 performs the inverse transform (conjugated roots, then every entry,
// real and imaginary part, is divided by len); any other op is the forward one.
void DFT(com *x, int op) {
    for (int i = 0; i < len; i++) if (i < rev[i]) std::swap(x[i], x[rev[i]]);
    for (int k = 2; k <= len; k <<= 1) {
        const int half = k >> 1, step = len / k;  // omega_k^i == g[step * i]
        for (int j = 0; j < len; j += k)
            for (int i = 0; i < half; i++) {
                com w = g[step * i];
                if (op == -1) w.b = -w.b;  // omega_k^(-i) is the conjugate of omega_k^i
                const com X = x[j + i], Y = x[j + i + half] * w;
                x[j + i] = X + Y, x[j + i + half] = X - Y;
            }
    }
    if (op == -1)
        for (int i = 0; i < len; i++) x[i].a /= len, x[i].b /= len;
}

// Writes the product of a (degree n, coefficients a[0..n]) and b (degree m,
// coefficients b[0..m]) into c[0..n+m], each rounded to the nearest integer.
// Requires n, m >= 0 and n + m < MAX_LEN.
void mul(int *a, int n, int *b, int m, int *c) {
    init(n + m);
    static com F[MAX_LEN], G[MAX_LEN];
    for (int i = 0; i < len; i++) F[i] = com(i <= n ? a[i] : 0.0, 0.0);
    for (int i = 0; i < len; i++) G[i] = com(i <= m ? b[i] : 0.0, 0.0);
    DFT(F, 1), DFT(G, 1);
    // Pointwise product of the two point-value forms, stored back into F.
    for (int i = 0; i < len; i++) F[i] = F[i] * G[i];
    DFT(F, -1);
    for (int i = 0; i <= n + m; i++) c[i] = static_cast<int>(std::round(F[i].a));
}
}
// Capacity of the coefficient arrays in main(); the product has n + m + 1
// coefficients, so inputs must satisfy n + m < maxn.
const int maxn = 2000000 + 10;
// Reads degrees n and m, then the n + 1 coefficients of the first polynomial and
// the m + 1 coefficients of the second, and prints the n + m + 1 coefficients of
// their product on a single line.
int main() {
#ifdef xunzhen
    freopen("FFT.in", "r", stdin);
    freopen("FFT.out", "w", stdout);
#endif
    int n = read(), m = read();
    // Reject degrees that would overflow a, b, c; maxn is also below
    // FFT::MAX_LEN, so this keeps FFT::mul inside its buffers too.
    if (n < 0 || m < 0 || static_cast<long long>(n) + m >= maxn) {
        fprintf(stderr, "invalid degrees: n = %d, m = %d\n", n, m);
        return 1;
    }
    static int a[maxn], b[maxn], c[maxn];
    for (int i = 0; i <= n; i++) a[i] = read();
    for (int i = 0; i <= m; i++) b[i] = read();
    FFT::mul(a, n, b, m, c);
    for (int i = 0; i <= n + m; i++) printf("%d%c", c[i], i < n + m ? ' ' : '\n');
    return 0;
}
NTT
#include <bits/stdc++.h>
using namespace std;
#define fst first
#define snd second
#define mp make_pair
#define squ(x) ((LL)(x) * (x))
#define debug(...) fprintf(stderr, __VA_ARGS__)
typedef long long LL;
typedef pair<int, int> pii;
// Raises a to b when b is strictly larger; returns whether a was changed.
template<typename T> inline bool chkmax(T &a, const T &b) {
    if (!(a < b)) return false;
    a = b;
    return true;
}
// Lowers a to b when b is strictly smaller; returns whether a was changed.
template<typename T> inline bool chkmin(T &a, const T &b) {
    if (a > b) {
        a = b;
        return true;
    }
    return false;
}
// Reads the next decimal integer from stdin. Characters before the first digit
// are skipped, and any '-' among them makes the result negative.
// Returns 0 if the input ends before any digit is found.
inline int read() {
    int sum = 0, fg = 1;
    // int (not char) so EOF stays distinguishable from real bytes and isdigit()
    // only ever receives EOF or an unsigned-char value, as it requires.
    int c = getchar();
    for (; c != EOF && !isdigit(c); c = getchar()) if (c == '-') fg = -1;
    for (; c != EOF && isdigit(c); c = getchar()) sum = sum * 10 + (c - '0');
    return fg * sum;
}
namespace NTT {
// Capacity of every internal buffer. The transform length len is a power of two
// no larger than this, so callers of mul() must keep n + m < MAX_LEN.
const int MAX_LEN = 1 << 21;
// 998244353 = 119 * 2^23 + 1 and 3 is a primitive root modulo it, so an element
// of order len exists for every power-of-two len up to 2^23.
const int mod = 998244353, g0 = 3;
// len: transform length; cnt: log2(len) - 1; rev: bit-reversal permutation;
// g[i] = w^i where w = g0^((mod - 1) / len) is a primitive len-th root of unity.
int len, cnt, rev[MAX_LEN], g[MAX_LEN];

// (x + y) reduced into [0, mod); assumes x + y lies in [-mod, 2 * mod).
inline int add(int x, int y) { return (x += y) < mod ? (x >= 0 ? x : x + mod) : x - mod; }
// x * y mod `mod`, computed in 64 bits to avoid overflow.
inline int mul(int x, int y) { return static_cast<long long>(x) * y % mod; }
// x^y mod `mod`. A negative y is turned into a positive exponent through
// Fermat's little theorem (x^-1 == x^(mod - 2)), so x must then be non-zero mod `mod`.
inline int Pow(int x, int y) {
    if (y < 0) y = -1LL * y * (mod - 2) % (mod - 1);
    int res = 1;
    for (; y; y >>= 1, x = mul(x, x)) if (y & 1) res = mul(res, x);
    return res;
}

// Prepares a transform able to hold a polynomial of degree N:
// len becomes the smallest power of two strictly greater than N.
void init(int N) {
    for (cnt = -1, len = 1; len <= N; len <<= 1) ++cnt;
    // The lowest bit of i becomes the highest of log2(len) bits, i.e. len / 2.
    // Expressed without "<< cnt" because cnt == -1 when len == 1, and a
    // negative shift count is undefined behaviour.
    for (int i = 0; i < len; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? len >> 1 : 0);
    g[0] = 1;
    // Modular arithmetic is exact, so the power table can be built incrementally.
    for (int G = Pow(g0, (mod - 1) / len), i = 1; i < len; i++) g[i] = mul(g[i - 1], G);
}

// In-place iterative radix-2 number-theoretic transform of x[0 .. len), entries
// expected in [0, mod). op == -1 performs the inverse transform (inverted roots,
// then every entry multiplied by len^-1); any other op is the forward one.
void DFT(int *x, int op) {
    for (int i = 0; i < len; i++) if (i < rev[i]) std::swap(x[i], x[rev[i]]);
    for (int k = 2; k <= len; k <<= 1) {
        const int half = k >> 1, step = len / k;  // w_k^i == g[step * i]
        for (int j = 0; j < len; j += k)
            for (int i = 0; i < half; i++) {
                // w_k^(-i) == w_k^(k - i); i == 0 stays at index 0.
                const int w = g[op == -1 ? (i ? step * (k - i) : 0) : step * i];
                const int X = x[j + i], Y = mul(x[j + i + half], w);
                x[j + i] = add(X, Y), x[j + i + half] = add(X, -Y);
            }
    }
    if (op == -1)
        for (int inv = Pow(len, -1), i = 0; i < len; i++) x[i] = mul(x[i], inv);
}

// Writes the product of a (degree n, coefficients a[0..n]) and b (degree m,
// coefficients b[0..m]) modulo `mod` into c[0..n+m], each in [0, mod).
// Requires n, m >= 0 and n + m < MAX_LEN.
void mul(int *a, int n, int *b, int m, int *c) {
    init(n + m);
    static int F[MAX_LEN], G[MAX_LEN];
    // Reduce inputs into [0, mod) so negative coefficients are handled correctly.
    for (int i = 0; i < len; i++) F[i] = i <= n ? (a[i] % mod + mod) % mod : 0;
    for (int i = 0; i < len; i++) G[i] = i <= m ? (b[i] % mod + mod) % mod : 0;
    DFT(F, 1), DFT(G, 1);
    // Pointwise product of the two point-value forms, stored back into F.
    for (int i = 0; i < len; i++) F[i] = mul(F[i], G[i]);
    DFT(F, -1);
    for (int i = 0; i <= n + m; i++) c[i] = F[i];
}
}
// Capacity of the coefficient arrays in main(); the product has n + m + 1
// coefficients, so inputs must satisfy n + m < maxn.
const int maxn = 2000000 + 10;
// Reads degrees n and m, then the n + 1 coefficients of the first polynomial and
// the m + 1 coefficients of the second, and prints the n + m + 1 coefficients of
// their product (mod 998244353) on a single line.
int main() {
#ifdef xunzhen
    freopen("NTT.in", "r", stdin);
    freopen("NTT.out", "w", stdout);
#endif
    int n = read(), m = read();
    // Reject degrees that would overflow a, b, c; maxn is also below
    // NTT::MAX_LEN, so this keeps NTT::mul inside its buffers too.
    if (n < 0 || m < 0 || static_cast<long long>(n) + m >= maxn) {
        fprintf(stderr, "invalid degrees: n = %d, m = %d\n", n, m);
        return 1;
    }
    static int a[maxn], b[maxn], c[maxn];
    for (int i = 0; i <= n; i++) a[i] = read();
    for (int i = 0; i <= m; i++) b[i] = read();
    NTT::mul(a, n, b, m, c);
    for (int i = 0; i <= n + m; i++) printf("%d%c", c[i], i < n + m ? ' ' : '\n');
    return 0;
}
Summary
其实NTT就是把FFT放在模意义下进行:对于模数\(p\)(代码中为\(998244353\)),取它的一个原根\(g\)(代码中为\(3\)),用\(g^{\frac{p - 1}{n}}\)来代替\(\omega_n\)即可(要求\(n\)整除\(p - 1\))
NTT可以用来避免浮点数运算的缓慢和精度误差,但好像取模运算更慢(雾
IDFT就先留个坑,等以后再来填算了