不知道多久才能做完多项式全家桶 qaq
多项式乘法
快速傅里叶变换 (FFT)
直接上链接(
总的来说就是先 DFT 从系数表示法到点值表示法,再 IDFT 从点值表示法到系数表示法。
简单说一下不太理解的,在 DFT 中 \(\omega_n^k = -\omega_n^{k+\frac{n}{2}}\),其实就是在单位圆上旋转了 \(180°\)。
感性理解:DFT 和 IDFT 是逆运算,所以 IDFT 里要把 \(\omega_n\) 换成它的倒数 \(\omega_n^{-1}\)(单位根的倒数就是它的共轭复数,即虚部取反,而不是相反数),最后再把每一项除以 \(n\)。
Code
#include <bits/stdc++.h>
#define ll long long
#define db double
#define gc getchar
#define pc putchar
using namespace std;
namespace IO
{
// Reads one decimal integer from stdin into x: skips every character before
// the first digit (remembering whether a '-' appeared), then consumes the run
// of digits. Works for integral and floating-point T.
// If input ends before any digit is found, x is left at zero instead of
// looping forever on EOF.
template <typename T>
void read(T &x)
{
    x = 0;
    bool neg = false;
    // int, not char: getchar() reports EOF out of band, and passing a negative
    // char (bytes >= 0x80 where char is signed) to isdigit is undefined.
    int c = gc();
    while(c != EOF && !isdigit(c)) neg |= (c == '-'), c = gc();
    while(isdigit(c)) x = x * 10 + (c - '0'), c = gc();
    if(neg) x = -x;
}
// Writes x to stdout in decimal (leading '-' for negatives), no separator.
// NOTE(review): negating the most negative value of a signed T overflows.
template <typename T>
void write(T x)
{
    if(x < 0) pc('-'), x = -x;
    if(x > 9) write(x / 10);
    pc('0' + x % 10);
}
}
using namespace IO;
struct Complex
{
db x, y;
Complex(db _x = 0.0, db _y = 0.0)
{
x = _x, y = _y;
}
Complex operator + (const Complex b) const
{
return Complex(x + b.x, y + b.y);
}
Complex operator - (const Complex b) const
{
return Complex(x - b.x, y - b.y);
}
Complex operator * (const Complex b) const
{
return Complex(x * b.x - y * b.y, x * b.y + y * b.x);
}
};
// Buffer capacity: arrays below are sized N << 2. Presumably this is so the
// padded power-of-two length fits for n, m < N (confirm against problem limits).
const int N = 1e6 + 5;
const db pi = acos(-1.0);
// Degrees of the two input polynomials.
int n, m;
// Coefficients of the two polynomials; transformed in place by FFT.
Complex f[N << 2];
Complex g[N << 2];
// Transform length (a power of two), its base-2 logarithm, and the
// bit-reversal permutation for that length.
int len = 1;
int bit;
int rev[N << 2];
// Iterative radix-2 FFT of a[0..len), in place, using the globals len and rev
// (the bit-reversal permutation for len, filled in main).
// type == 1: forward DFT. type == -1: inverse DFT, including division of the
// real parts by len (imaginary parts are not scaled; main reads only .x).
void FFT(Complex a[], int type)
{
    for(int i = 0; i < len; i++)
        if(i < rev[i]) swap(a[i], a[rev[i]]);
    // root[k] = e^{type * i * pi * k / half}. Every twiddle factor needed at any
    // level is one of these, so each one is computed directly with cos/sin.
    // Generating them by repeated multiplication by a primitive root instead
    // accumulates rounding error over up to len / 2 products.
    int half = len >> 1;
    vector<Complex> root(half);
    for(int k = 0; k < half; k++)
        root[k] = Complex(cos(pi * k / half), type * sin(pi * k / half));
    for(int mid = 1; mid < len; mid <<= 1)
    {
        // root[j * step] = e^{type * i * pi * j / mid}, the j-th twiddle of this level.
        int step = half / mid;
        for(int i = 0; i < len; i += (mid << 1))
            for(int j = 0; j < mid; j++)
            {
                Complex x = a[i + j], y = root[j * step] * a[i + mid + j];
                a[i + j] = x + y;
                a[i + mid + j] = x - y;
            }
    }
    if(type == -1)
        for(int i = 0; i < len; i++)
            a[i].x /= len;
}
// Reads two polynomials (degrees n and m, coefficients in increasing order),
// multiplies them with FFT and prints the n + m + 1 product coefficients.
int main()
{
    read(n), read(m);
    for(int i = 0; i <= n; i++) read(f[i].x);
    for(int i = 0; i <= m; i++) read(g[i].x);
    // Smallest power of two strictly greater than n + m.
    while(len <= n + m) len <<= 1, bit++;
    // rev[0] is always 0. Starting at 1 also avoids the undefined shift by
    // bit - 1 == -1 when len == 1 (n == m == 0).
    for(int i = 1; i < len; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
    FFT(f, 1);
    FFT(g, 1);
    for(int i = 0; i < len; i++) f[i] = f[i] * g[i];
    FFT(f, -1);
    // llround rounds to nearest for negative values too; (ll)(x + 0.5) does not.
    for(int i = 0; i <= n + m; i++)
        printf("%lld ", llround(f[i].x));
    pc('\n');
    return 0;
}
// A.S.
快速数论变换 (NTT)
继续上链接(
就是把 FFT 中的 \(\omega_n\) 改为了 \(g^{\frac{p-1}{n}}\),其中 \(g\) 为模数的原根。
\[\omega_n \equiv g^{\frac{p-1}{n}}\ (\bmod p) \]证明就算了
然后在 IDFT 时把原根 \(g\) 换成它的逆元 \(g^{-1}\)(代码中的 \(Gi\)),最后再把每一项乘上 \(n^{-1}\) 就行了(
Code
#include <bits/stdc++.h>
#define ll long long
#define db double
#define gc getchar
#define pc putchar
#define swap(a, b) a ^= b ^= a ^= b
using namespace std;
namespace IO
{
// Reads one decimal integer from stdin into x: skips every character before
// the first digit (remembering whether a '-' appeared), then consumes the run
// of digits.
// If input ends before any digit is found, x is left at zero instead of
// looping forever on EOF.
template <typename T>
void read(T &x)
{
    x = 0;
    bool neg = false;
    // int, not char: getchar() reports EOF out of band, and passing a negative
    // char (bytes >= 0x80 where char is signed) to isdigit is undefined.
    int c = gc();
    while(c != EOF && !isdigit(c)) neg |= (c == '-'), c = gc();
    while(isdigit(c)) x = x * 10 + (c - '0'), c = gc();
    if(neg) x = -x;
}
// Writes x to stdout in decimal (leading '-' for negatives), no separator.
// NOTE(review): negating the most negative value of a signed T overflows.
template <typename T>
void write(T x)
{
    if(x < 0) pc('-'), x = -x;
    if(x > 9) write(x / 10);
    pc('0' + x % 10);
}
}
using namespace IO;
const int N = 1e6 + 5;
// Modulus and its primitive root G; Gi = G^{-1} mod p (3 * 332748118 == p + 1).
const int p = 998244353;
const int G = 3;
const int Gi = 332748118;
// Degrees of the two input polynomials.
int n, m;
// Transform length (power of two) and its base-2 logarithm.
int len = 1;
int bit;
// Coefficients of the two polynomials; transformed in place by NTT.
ll f[N << 2];
ll g[N << 2];
// Bit-reversal permutation for length len.
int rev[N << 2];
// Binary exponentiation: returns a^b mod p (intended for b >= 0).
ll qpow(ll a, int b)
{
    ll acc = 1;
    for(; b != 0; b >>= 1, a = a * a % p)
        if(b & 1) acc = acc * a % p;
    return acc % p;
}
// In-place iterative number-theoretic transform of a[0..len) modulo p, using
// the globals len and rev (bit-reversal permutation filled in main).
// type == 1: forward transform with root G.
// type == -1: inverse transform with root Gi = G^{-1}, then every entry of a
// is multiplied by len^{-1}.
void NTT(ll a[], int type)
{
    for(int i = 0; i < len; i++)
        if(i < rev[i])
        {
            ll t = a[i];
            a[i] = a[rev[i]];
            a[rev[i]] = t;
        }
    for(int mid = 1; mid < len; mid <<= 1)
    {
        ll wn = qpow(type == 1 ? G : Gi, (p - 1) / (mid << 1));
        for(int i = 0; i < len; i += (mid << 1))
        {
            ll w = 1;
            for(int j = 0; j < mid; j++, w = w * wn % p)
            {
                ll x = a[i + j], y = w * a[i + mid + j] % p;
                a[i + j] = (x + y) % p;
                a[i + mid + j] = (x - y + p) % p;
            }
        }
    }
    if(type == -1)
    {
        // Scale the array that was transformed (this used to write the global
        // f, which was only correct because main happened to pass f).
        ll inv = qpow(len, p - 2);
        for(int i = 0; i < len; i++)
            a[i] = a[i] * inv % p;
    }
}
// Reads two polynomials (degrees n and m), multiplies them modulo p with NTT
// and prints the n + m + 1 product coefficients.
int main()
{
    read(n), read(m);
    // Reduce coefficients into [0, p). read() accepts negative numbers, and the
    // butterfly's (x - y + p) % p assumes non-negative residues.
    for(int i = 0; i <= n; i++) read(f[i]), f[i] = (f[i] % p + p) % p;
    for(int i = 0; i <= m; i++) read(g[i]), g[i] = (g[i] % p + p) % p;
    // Smallest power of two strictly greater than n + m.
    while(len <= n + m) len <<= 1, bit++;
    // rev[0] is always 0. Starting at 1 also avoids the undefined shift by
    // bit - 1 == -1 when len == 1 (n == m == 0).
    for(int i = 1; i < len; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
    NTT(f, 1);
    NTT(g, 1);
    for(int i = 0; i < len; i++) f[i] = f[i] * g[i] % p;
    NTT(f, -1);
    for(int i = 0; i <= n + m; i++)
        write(f[i]), pc(' ');
    pc('\n');
    return 0;
}
// A.S.
多项式求逆
对于一个多项式 \(F(x)\),求一个多项式 \(G(x)\),满足 \(F(x)*G(x)\equiv 1\ (\bmod x^n)\),系数对 \(998244353\) 取模。
就是多项式的逆元。
推一下式子
设
\[A(x)*B(x) \equiv 1\ (\bmod x^n)\\ A(x)*C(x)\equiv 1\ (\bmod x^{\frac{n}{2}}) \]那么
\[A(x)*(B(x)-C(x))\equiv 0\ (\bmod x^{\frac{n}{2}}) \\ B(x)-C(x)\equiv 0\ (\bmod x^{\frac{n}{2}}) \]我们要把模数改为 \(x^n\),只需要平方一下
\[[B(x)-C(x)]^2\equiv 0\ (\bmod x^n) \\ B^2(x)-2B(x)*C(x)+C^2(x)\equiv 0\ (\bmod x^n)\\ \]将等式左边乘上 \(A(x)\) 不影响等号
因为
\[A(x)*B(x)\equiv 1\ (\bmod x^n) \]所以可以将其中的 \(A(x)*B(x)\) 都替换为 \(1\),得到
\[B(x)-2C(x)+A(x)*C^2(x)\equiv 0\ (\bmod x^n) \\ B(x)\equiv 2C(x)-A(x)*C^2(x)\ (\bmod x^n) \]我们只需要求出 \(C(x)\),而 \(C(x)\) 与 \(B(x)\) 的形式是一样的,只是模数不一样,所以可以递归求解,当然也可以递推
复杂度 \(O(n\log n)\)
这里写的递推,但是常数好像不是很优秀(
\(bas\) 是当前多项式的项数,\(len\) 是当前多项式乘起来后的项数,也就是 \(2\times bas\)。
Code
#include <bits/stdc++.h>
#define ll long long
#define db double
#define gc getchar
#define pc putchar
#define swap(a, b) a ^= b ^= a ^= b
using namespace std;
namespace IO
{
// Reads one decimal integer from stdin into x: skips every character before
// the first digit (remembering whether a '-' appeared), then consumes the run
// of digits.
// If input ends before any digit is found, x is left at zero instead of
// looping forever on EOF.
template <typename T>
void read(T &x)
{
    x = 0;
    bool neg = false;
    // int, not char: getchar() reports EOF out of band, and passing a negative
    // char (bytes >= 0x80 where char is signed) to isdigit is undefined.
    int c = gc();
    while(c != EOF && !isdigit(c)) neg |= (c == '-'), c = gc();
    while(isdigit(c)) x = x * 10 + (c - '0'), c = gc();
    if(neg) x = -x;
}
// Writes x to stdout in decimal (leading '-' for negatives), no separator.
// NOTE(review): negating the most negative value of a signed T overflows.
template <typename T>
void write(T x)
{
    if(x < 0) pc('-'), x = -x;
    if(x > 9) write(x / 10);
    pc('0' + x % 10);
}
}
using namespace IO;
const int N = 1e5 + 5;
// Modulus and its primitive root G; Gi = G^{-1} mod p (3 * 332748118 == p + 1).
const int p = 998244353;
const int G = 3;
const int Gi = 332748118;
// Number of coefficients of the input polynomial.
int n;
// Input polynomial A(x).
ll a[N << 2];
// Two rows used alternately by the Newton iteration in solve().
ll b[2][N << 2];
// Bit-reversal permutation for the current transform length.
int rev[N << 2];
// Binary exponentiation: returns a^b mod p (intended for b >= 0).
ll qpow(ll a, int b)
{
    ll acc = 1;
    for(; b != 0; b >>= 1, a = a * a % p)
        if(b & 1) acc = acc * a % p;
    return acc;
}
// Fills rev[0..len) with the bit-reversal permutation for a transform of
// length len. Expects len == 1 << bit with bit >= 1 (bit == 0 would shift by -1).
// Each entry reuses the already computed rev[i >> 1].
void calcrev(int len, int bit)
{
    const int top = bit - 1;
    for(int i = 0; i < len; ++i)
        rev[i] = ((i & 1) << top) | (rev[i >> 1] >> 1);
}
// (x + y) mod p for residues x, y in [0, p).
ll add(ll x, ll y)
{
    ll s = x + y;
    if(s >= p) s -= p;
    return s;
}
// (x - y) mod p for residues x, y in [0, p).
ll sub(ll x, ll y)
{
    return x >= y ? x - y : x - y + p;
}
// In-place number-theoretic transform of a[0..len) modulo p. The global rev
// must already hold the bit-reversal permutation for len.
// type == 1: forward transform with root G.
// type == -1: inverse transform with root Gi, then every entry times len^{-1}.
void NTT(ll a[], int len, int type)
{
    for(int i = 0; i < len; ++i)
    {
        const int r = rev[i];
        if(i < r)
        {
            const ll t = a[i];
            a[i] = a[r];
            a[r] = t;
        }
    }
    const int root = (type == 1) ? G : Gi;
    for(int half = 1; half < len; half <<= 1)
    {
        const int span = half << 1;
        const ll step = qpow(root, (p - 1) / span);
        for(int start = 0; start < len; start += span)
        {
            ll w = 1;
            for(int j = 0; j < half; ++j)
            {
                const ll u = a[start + j];
                const ll v = w * a[start + half + j] % p;
                a[start + j] = add(u, v);
                a[start + half + j] = sub(u, v);
                w = w * step % p;
            }
        }
    }
    if(type != -1) return;
    const ll scale = qpow(len, p - 2);
    for(int i = 0; i < len; ++i)
        a[i] = a[i] * scale % p;
}
// Scratch buffers for mul().
ll X[N << 2], Y[N << 2];
// Multiplies the first len / 2 coefficients of x by the first len / 2
// coefficients of y and stores all len product coefficients into x[0..len).
// len must be a power of two, and rev must already hold the bit-reversal
// permutation for len. x and y may be the same array.
void mul(ll x[], ll y[], int len)
{
    int half = len >> 1;
    for(int i = 0; i < half; i++) X[i] = x[i], Y[i] = y[i];
    // Zero-pad only the upper half that the transforms read. Clearing both
    // whole buffers on every call cost O(N) regardless of len.
    fill(X + half, X + len, 0LL);
    fill(Y + half, Y + len, 0LL);
    NTT(X, len, 1);
    NTT(Y, len, 1);
    for(int i = 0; i < len; i++) X[i] = X[i] * Y[i] % p;
    NTT(X, len, -1);
    for(int i = 0; i < len; i++) x[i] = X[i];
}
void solve()
{
int k = 0, bas = 1, bit = 1, len = 2;
b[k][0] = qpow(a[0], p - 2);
while(bas < (n << 1))
{
calcrev(len, bit);
k ^= 1;
for(int i = 0; i < bas; i++) b[k][i] = add(b[k ^ 1][i], b[k ^ 1][i]);
mul(b[k ^ 1], b[k ^ 1], len);
mul(b[k ^ 1], a, len);
for(int i = 0; i < bas; i++) b[k][i] = sub(b[k][i], b[k ^ 1][i]);
bas <<= 1, len <<= 1, bit++;
}
for(int i = 0; i < n; i++)
write(b[k][i]), pc(' ');
pc('\n');
return;
}
// Reads n and the n coefficients of A(x), then prints A(x)^{-1} mod x^n.
int main()
{
    read(n);
    // Reduce each coefficient into [0, p). read() accepts negative numbers,
    // and a plain % p would leave them negative.
    for(int i = 0; i < n; i++) read(a[i]), a[i] = (a[i] % p + p) % p;
    solve();
    return 0;
}
// A.S.
多项式 ln
给定一个多项式 \(A(x)\),求一个多项式 \(B(x)\),满足 \(B(x)\equiv \ln A(x)\ (\bmod x^n)\)
设 \(f(x) = \ln x\)
\(\ln\) 不好处理,但是对 \(\ln\) 求导后就很好算了,\(f'(x)=\dfrac{1}{x}\)
所以将同余号两边同时求导,\(B'(x)\equiv f'(A(x))*A'(x)\) (复合函数求导)
因为 \(f'(A(x))=\dfrac{1}{A(x)}\)
所以 \(B'(x)\equiv \dfrac{A'(x)}{A(x)}\)
有了 \(B'(x)\) 后再积分求出 \(B(x)\)。积分会丢掉常数项,而常数项应为 \(\ln A(0)\);这里要求 \(A(0)=1\),所以常数项取 \(0\) 即可。
求导公式:\((x^a)'=ax^{a-1}\)
积分公式:\(\int x^adx=\dfrac{1}{a+1}x^{a+1}\)
积分就是求导的逆运算,你会发现把 \(\dfrac{1}{a+1}x^{a+1}\) 求导后就是 \(x^a\)
Code
#include <bits/stdc++.h>
#define ll long long
#define db double
#define gc getchar
#define pc putchar
#define swap(a, b) a ^= b ^= a ^= b
using namespace std;
namespace IO
{
// Reads one decimal integer from stdin into x: skips every character before
// the first digit (remembering whether a '-' appeared), then consumes the run
// of digits.
// If input ends before any digit is found, x is left at zero instead of
// looping forever on EOF.
template <typename T>
void read(T &x)
{
    x = 0;
    bool neg = false;
    // int, not char: getchar() reports EOF out of band, and passing a negative
    // char (bytes >= 0x80 where char is signed) to isdigit is undefined.
    int c = gc();
    while(c != EOF && !isdigit(c)) neg |= (c == '-'), c = gc();
    while(isdigit(c)) x = x * 10 + (c - '0'), c = gc();
    if(neg) x = -x;
}
// Writes x to stdout in decimal (leading '-' for negatives), no separator.
// NOTE(review): negating the most negative value of a signed T overflows.
template <typename T>
void write(T x)
{
    if(x < 0) pc('-'), x = -x;
    if(x > 9) write(x / 10);
    pc('0' + x % 10);
}
}
using namespace IO;
const int N = 1e5 + 5;
// Modulus (998244353 = 119 * 2^23 + 1), its primitive root G, and
// Gi = G^{-1} mod p (3 * 332748118 == p + 1).
const int p = 998244353;
const int G = 3;
const int Gi = 332748118;
// Number of coefficients.
int n;
// Input polynomial and its logarithm.
ll f[N << 2];
ll g[N << 2];
// Work buffers for Ln(): a holds the derivative (then derivative * inverse),
// b holds the inverse.
ll a[N << 2];
ll b[N << 2];
// Binary exponentiation: returns a^b mod p (intended for b >= 0).
ll qpow(ll a, int b)
{
    ll acc = 1;
    for(; b != 0; b >>= 1, a = a * a % p)
        if(b & 1) acc = acc * a % p;
    return acc;
}
// Folds a sum of two residues, x in [0, 2p), back into [0, p).
ll add(ll x) { return x < p ? x : x - p; }
// Folds a difference of two residues, x in [-p, p), back into [0, p).
ll sub(ll x) { return x >= 0 ? x : x + p; }
// Copies y[0..len) into x[0..len), front to back.
void Copy(ll *x, ll *y, int len) { for(int k = 0; k < len; ++k) x[k] = y[k]; }
// Bit-reversal permutation shared by NTT().
int rev[N << 2];
// Fills rev[0..len) with the bit-reversal permutation for a transform of
// length len. Expects len == 1 << bit with bit >= 1 (bit == 0 would shift by -1).
void calcrev(int len, int bit)
{
    const int top = bit - 1;
    for(int i = 0; i < len; ++i)
        rev[i] = ((i & 1) << top) | (rev[i >> 1] >> 1);
}
// Number-theoretic transform of a[0..len) modulo p, in place. The global rev
// must already hold the bit-reversal permutation for len.
// type == 1: forward transform with root G.
// type == -1: inverse transform with root Gi, then every entry times len^{-1}.
void NTT(ll *a, int len, int type)
{
    for(int i = 0; i < len; ++i)
    {
        const int r = rev[i];
        if(i < r)
        {
            const ll t = a[i];
            a[i] = a[r];
            a[r] = t;
        }
    }
    const int root = (type == 1) ? G : Gi;
    for(int half = 1; half < len; half <<= 1)
    {
        const int span = half << 1;
        const ll step = qpow(root, (p - 1) / span);
        for(int start = 0; start < len; start += span)
        {
            ll w = 1;
            for(int j = 0; j < half; ++j)
            {
                const ll u = a[start + j];
                const ll v = w * a[start + half + j] % p;
                a[start + j] = add(u + v);
                a[start + half + j] = sub(u - v);
                w = w * step % p;
            }
        }
    }
    if(type == -1)
    {
        const ll scale = qpow(len, p - 2);
        for(int i = 0; i < len; ++i)
            a[i] = a[i] * scale % p;
    }
}
// Scratch buffers for mul().
ll X[N << 2], Y[N << 2];
// Polynomial multiplication: multiplies the first len / 2 coefficients of x by
// the first len / 2 coefficients of y and stores all len product coefficients
// into x[0..len). len must be a power of two, and rev must already hold the
// bit-reversal permutation for len. x and y may be the same array.
void mul(ll *x, ll *y, int len)
{
    int half = len >> 1;
    Copy(X, x, half), Copy(Y, y, half);
    // Zero-pad only the upper half that the transforms read. Clearing both
    // whole buffers on every call cost O(N) regardless of len.
    fill(X + half, X + len, 0LL);
    fill(Y + half, Y + len, 0LL);
    NTT(X, len, 1);
    NTT(Y, len, 1);
    for(int i = 0; i < len; i++) X[i] = X[i] * Y[i] % p;
    NTT(X, len, -1);
    Copy(x, X, len);
}
// Two rows used alternately by the Newton iteration in Inv().
ll inv[2][N << 2];
// Polynomial inverse: writes into y[0..n) the coefficients of the series y with
// x * y ≡ 1 (mod x^n). Uses the Newton step D ← 2C − x·C², where C is the
// previous approximation. x[0] must be invertible mod p. The rev table is left
// holding the permutation for the last length used.
void Inv(ll *x, ll *y, int n)
{
    int cur = 0;
    inv[cur][0] = qpow(x[0], p - 2);
    for(int bas = 1, len = 2, bit = 1; bas < (n << 1); bas <<= 1, len <<= 1, ++bit)
    {
        calcrev(len, bit);
        const int prev = cur;
        cur ^= 1;
        ll *nxt = inv[cur];
        ll *old = inv[prev];
        for(int i = 0; i < bas; ++i) nxt[i] = add(old[i] + old[i]);  // 2C
        mul(old, old, len);                                           // C^2
        mul(old, x, len);                                             // x * C^2
        for(int i = 0; i < bas; ++i) nxt[i] = sub(nxt[i] - old[i]);  // 2C - x*C^2
    }
    Copy(y, inv[cur], n);
}
// Derivative: y[i] = (i + 1) * x[i + 1] mod p for i < n - 1, and y[n - 1] = 0.
void Differential(ll *x, ll *y, int n)
{
    for(int i = 0; i + 1 < n; ++i)
        y[i] = (i + 1) * x[i + 1] % p;
    y[n - 1] = 0;
}
// Integral: y[i] = x[i - 1] / i mod p for 1 <= i < n, with constant term y[0] = 0.
void Integral(ll *x, ll *y, int n)
{
    for(int i = 1; i < n; ++i)
    {
        const ll invI = qpow(i, p - 2); // modular inverse of i (Fermat)
        y[i] = x[i - 1] * invI % p;
    }
    y[0] = 0;
}
// Returns the smallest power of two strictly greater than 2 * n.
int calclen(int n)
{
    int len = 1;
    for(; len <= (n << 1); len *= 2) {}
    return len;
}
// Polynomial logarithm: y = ln(x) mod x^n, computed as the integral of x' / x.
// Integral() sets the constant term to 0, so this assumes x[0] == 1.
void Ln(ll *x, ll *y, int n)
{
    // Differential() would write a[-1] for n == 0.
    if(n <= 0) return;
    Differential(x, a, n);   // a = x'
    Inv(x, b, n);            // b = x^{-1} mod x^n
    // mul() depends on the global rev table. Inv() leaves it built for its own
    // last length, which differs from calclen(n) whenever n is a power of two,
    // so rebuild it for the length used here.
    int len = calclen(n), bit = 0;
    while((1 << bit) < len) bit++;
    calcrev(len, bit);
    mul(a, b, len);          // a = x' * x^{-1}
    Integral(a, y, n);
}
int main()
{
read(n);
for(int i = 0; i < n; i++) read(f[i]);
Ln(f, g, n);
for(int i = 0; i < n; i++)
write(g[i]), pc(' ');
pc('\n');
return 0;
}
// A.S.