多项式入门

https://www.luogu.com.cn/blog/command-block/ntt-yu-duo-xiang-shi-quan-jia-tong

FFT

\(\bullet\) 单位根

\(n\) 次单位根指 \(n\) 次幂为 \(1\) 的复数。

将单位圆 \(n\) 等分,圆周上辐角为 \(\dfrac {k\pi} {n}(k\in[0,n))\) 的复数均为 \(n\) 次单位根。

一些性质:

\(\cdot\) \(\omega_n^0 = 1,\omega_n^1 = cos \dfrac {2\pi} {n} + i sin \dfrac {2\pi} n\)

\(\cdot\) \(\omega_n^k = (\omega_n^1)^k\)

\(\cdot\) \(\omega_{2n}^{2k} = \omega_n^k\)

\(\cdot\) 若 \(n\) 为偶数,\(\omega_n^k = -\omega_n^{k+\frac n 2}\)

\(\bullet\) \(\texttt{DFT}\)

将 \(n(2|n)\) 项多项式 \(F\) 按下标分成两部分:

\(F_0 = a_0 + a_2x + a_4x^2+\cdots+a_{n-2}x^{n/2-1}\)

\(F_1 = a_1 + a_3x + a_5x^2+\cdots+a_{n-1}x^{n/2-1}\)

则 \(F(x) = F_0(x^2) + xF_1(x^2)\)。

将 \(\omega_n^k(k<n/2)\) 代入:

\(\begin{aligned}F(\omega_n^k) &= F_0(\omega_n^{2k}) + \omega_n^kF_1(\omega_n^{2k})\\&=F_0(\omega_{\frac n 2}^k) + \omega_n^kF_1(\omega_{\frac n 2}^k)\end{aligned}\)

将 \(\omega_n^{k+n/2}(k<n/2)\) 代入:

\(\begin{aligned}F(\omega_n^{k+n/2})&= F_0(\omega_n^{2k+n}) + \omega_n^{k + n / 2}F_1(\omega_n^{2k+n})\\&=F_0(\omega_n^{2k})-\omega_n^kF_1(\omega_n^{2k})\\&=F_0(\omega_{\frac n 2}^k) - \omega_n^kF_1(\omega_{\frac n 2}^k)\end{aligned}\)

对比两式,只有一个正负号的区别,每个区间只需要求一半,可以做到 \(O(n\log n)\)。

蝴蝶变换:(不递归实现,减小常数)

多项式入门

从下往上合并即可。

\(\bullet\) \(\texttt{IDFT}\)

将 \(\texttt{DFT}\) 过程中的单位根变成 \(\omega_n^k(k\in (-n,0])\) 做一遍,然后将系数除以 \(n\) 即可。

\(\texttt{板子}\)

\(\texttt{Code:}\)

#include <bits/stdc++.h>
using namespace std;
const double pi = acos(-1);
const int N = (1 << 21) + 5;
struct Complex {double x, y;} F[N], G[N];
int n, m, tr[N];
Complex operator + (Complex a, Complex b) {
	return (Complex){a.x + b.x, a.y + b.y};
}
Complex operator - (Complex a, Complex b) {
	return (Complex){a.x - b.x, a.y - b.y};
}
Complex operator * (Complex a, Complex b) {
	return (Complex){a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x};
}
void fft(Complex *f, int tp) {
	for (int i = 0; i < n; i++)
	if (i < tr[i]) swap(f[tr[i]], f[i]);
	for (int len = 2; len <= n; len <<= 1) {
		Complex w0 = (Complex){cos(2.0 * pi / len), sin(2.0 * pi / len) * tp};
		for (int s = 0; s < n; s += len) {
			Complex w = (Complex){1, 0};
			for (int i = 0; i < len / 2; i++) {
				Complex cur = w * f[s + i + (len >> 1)];
				f[s + i + (len >> 1)] = f[s + i] - cur;
				f[s + i] = f[s + i] + cur;
				w = w * w0;
			}
		}
	}
}
int main() {
	scanf("%d%d", &n, &m);
	for (int i = 0; i <= n; i++) scanf("%lf", &F[i].x);
	for (int i = 0; i <= m; i++) scanf("%lf", &G[i].x);
	m += n; n = 1; while (n <= m) n <<= 1;
	tr[0] = 0;
	for (int i = 1; i < n; i++)
		tr[i] = (tr[i >> 1] >> 1) | (i & 1 ? n >> 1 : 0);
	fft(F, 1); fft(G, 1);
	for (int i = 0; i < n; i++) F[i] = F[i] * G[i];
	fft(F, -1);
	for (int i = 0; i <= m; i++) printf("%d ", (int)(F[i].x / n + 0.5));
	return 0;
}

NTT

\(\bullet\) 原根

设 \(p\) 是正整数,\(a\) 是整数,若 \(a\) 模 \(p\) 的阶等于\(\varphi(p)\),则称 \(a\) 为模 \(p\) 的一个原根

原根也具有单位根所利用到的性质。(不然怎么能换

设 \(g\) 为 \(mod\) 的原根,将 \(\omega_n^1\) 替换为 \(g^{(mod - 1) / n}\),类似 \(fft\) 做一遍即可。

\(\texttt{Code:}\)

#include <bits/stdc++.h>
using namespace std;
const int cmd = 998244353;
const int N = (1 << 21) + 5;
int n, m, F[N], G[N], yg, invyg, tr[N];
int fpow(int a, int b) {
	int res = 1;
	for (; b; b >>= 1, a = 1ll * a * a % cmd)
		if (b & 1) res = 1ll * res * a % cmd;
	return res;
}
void ntt(int *f, bool tp) {
	for (int i = 0; i < n; i++)
	if (i < tr[i]) swap(f[i], f[tr[i]]);
	for (int len = 2; len <= n; len <<= 1) {
		int w0 = fpow(tp ? yg : invyg, (cmd - 1) / len);
		for (int s = 0; s < n; s += len) {
			int w = 1;
			for (int i = 0; i < (len >> 1); i++) {
				int cur = 1ll * w * f[s + i + (len >> 1)] % cmd;
				f[s + i + (len >> 1)] = (f[s + i] - cur + cmd) % cmd;
				f[s + i] = (f[s + i] + cur) % cmd;
				w = 1ll * w * w0 % cmd;
			}
		}
	}
}
int main() {
	scanf("%d%d", &n, &m);
	yg = 3; invyg = fpow(yg, cmd - 2);
	for (int i = 0; i <= n; i++) scanf("%d", &F[i]);
	for (int i = 0; i <= m; i++) scanf("%d", &G[i]);
	m += n; n = 1; for (; n <= m; n <<= 1);
	for (int i = 1; i < n; i++)
		tr[i] = (tr[i >> 1] >> 1) | (i & 1 ? n >> 1 : 0);
	ntt(F, 1); ntt(G, 1);
	for (int i = 0; i < n; i++) F[i] = 1ll * F[i] * G[i] % cmd;
	ntt(F, 0); int invn = fpow(n, cmd - 2);
	for (int i = 0; i <= m; i++) printf("%lld ", 1ll * F[i] * invn % cmd);
	return 0;
}
上一篇:数字信号处理学习笔记[3] 滤波与褶积,Z变换


下一篇:百面机器学习1——特征工程篇