Pre
式子变换需要注意一下
Solution
注意到\(f(x)g(x)+f_0=f(x)\)
其实开始我没看出来,后来发现仔细分析一下就可以了。
然后式子变换
\(f(x)=\frac{f_0}{1-g(x)}\)
注意这里的\(1-g(x)\)是多项式减法:\(1\)只加在常数项上,其余各项系数直接取相反数,因为这里的\(f(x)\)和\(g(x)\)指的是整个函数(多项式),而不是某一项的系数。
Code
#include <cstdio>
#include <queue>
#include <cstring>
#define ll long long
#define xx first
#define yy second
using namespace std;
// Exchanges the values held by the two referenced ints.
// Safe when both references name the same object.
inline void swap (int &a, int &b) {
    const int saved = b;
    b = a;
    a = saved;
}
// N:     capacity of every coefficient / scratch array in this file.
//        Inv transforms at the smallest power of two >= 2 * deg, which is
//        2^18 = 262144 as soon as deg exceeds 65536, so the arrays must hold
//        at least 2^18 entries (this supports deg up to 131072).
// mod:   the NTT-friendly prime 998244353 = 119 * 2^23 + 1.
// inver: modular inverse of 3 (3 * 332748118 = 998244354 = 1 mod `mod`);
//        NTT uses 3 as the forward root base and `inver` for the inverse.
const int N = (1 << 18) + 5, mod = 998244353, inver = 332748118;
// nn: number of output coefficients requested.
// g:  main stores the coefficients of 1 - g(x) here.
// f:  receives the resulting series f(x) = 1 / (1 - g(x)).
int nn, g[N], f[N];
// Returns u + v, subtracting mod once when the sum reaches or exceeds it.
inline int add (int u, int v) {
    const int sum = u + v;
    if (sum >= mod) return sum - mod;
    return sum;
}
// Returns u - v, adding mod once when the difference is negative.
inline int mns (int u, int v) {
    const int diff = u - v;
    if (diff < 0) return diff + mod;
    return diff;
}
// Returns u * v modulo mod; the product is formed in 64 bits so it cannot overflow int.
inline int mul (int u, int v) {
    const long long product = 1LL * u * v;
    return product % mod;
}
// Binary exponentiation: returns u^v modulo mod.
// The exponent is consumed bit by bit from the least significant end while
// the running power is squared on every step.
inline int qpow (int u, int v) {
    int result = 1;
    for (int cur = u % mod; v; v >>= 1, cur = mul (cur, cur)) {
        if (v & 1) result = mul (result, cur);
    }
    return result;
}
// Scratch copy of the input polynomial used by Inv.
int c[N];
// Bit-reversal index table, rebuilt by every NTT call.
int rev[N];
// In-place number-theoretic transform of a[0..n) modulo mod.
//   a    : coefficient / value array, overwritten with the result
//   n    : transform length; the callers in this file pass a power of two
//   bit  : log2(n), used to build the bit-reversal table
//   flag : false = forward transform (root base 3),
//          true  = inverse transform (root base inver, result scaled by 1/n)
inline void NTT (int *a, int n, int bit, bool flag) {
    // Build the bit-reversal table and apply the permutation in the same pass;
    // swapping only when rev[i] < i touches each pair exactly once.
    for (int i = 0; i < n; ++i) {
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
        if (rev[i] < i) swap (a[i], a[rev[i]]);
    }

    const int root = flag ? inver : 3;
    for (int len = 2; len <= n; len <<= 1) {
        const int half = len >> 1;
        const int step = qpow (root, (mod - 1) / len);  // root of unity of order len
        for (int start = 0; start < n; start += len) {
            int w = 1;
            for (int j = start; j < start + half; ++j) {
                // Butterfly on the pair (j, j + half).
                const int hi = mul (a[j + half], w);
                const int lo = a[j];
                a[j + half] = mns (lo, hi);
                a[j] = add (lo, hi);
                w = mul (w, step);
            }
        }
    }

    if (!flag) return;

    // The inverse transform additionally divides every entry by n.
    const int scale = qpow (n, mod - 2);
    for (int i = 0; i < n; ++i) a[i] = mul (a[i], scale);
}
// Power-series inverse: writes into b[0..deg) the coefficients of 1 / a(x)
// modulo x^deg (arithmetic modulo mod), using the Newton step
//     b <- b * (2 - a * b)
// which doubles the number of correct coefficients on each level.
//   a   : input coefficients; a[0..deg) are read. a[0] is inverted as
//         a[0]^(mod-2), so it must be non-zero modulo mod.
//   b   : output; must have room for the transform length, i.e. the smallest
//         power of two >= 2 * deg. Entries from deg up to that length are
//         zero on return.
//   deg : number of coefficients wanted; values <= 0 are a no-op.
// Uses the file-level scratch array c (and rev, through NTT).
inline void Inv (int *a, int *b, int deg) {
    // Without this guard deg <= 0 would recurse forever, since (0 + 1) >> 1 == 0.
    if (deg <= 0) return;
    if (deg == 1) {
        b[0] = qpow (a[0], mod - 2);
        return;
    }

    // First obtain the inverse modulo x^ceil(deg / 2).
    const int half = (deg + 1) >> 1;
    Inv (a, b, half);

    // The product b * b * a has fewer than 2 * deg coefficients, so a transform
    // of length n >= 2 * deg avoids wrap-around.
    int n = 1, bit = 0;
    while (n < (deg << 1)) n <<= 1, ++bit;

    for (int i = 0; i < deg; ++i) c[i] = a[i];
    for (int i = deg; i < n; ++i) c[i] = 0;
    // Only b[0..half) is meaningful here; clear the rest explicitly instead of
    // relying on the caller having passed a zero-filled buffer.
    for (int i = half; i < n; ++i) b[i] = 0;

    NTT (c, n, bit, false);
    NTT (b, n, bit, false);
    for (int i = 0; i < n; ++i) {
        b[i] = mns (mul (2, b[i]), mul (c[i], mul (b[i], b[i])));
    }
    NTT (b, n, bit, true);

    // Truncate to deg coefficients.
    for (int i = deg; i < n; ++i) b[i] = 0;
}
// Reads nn followed by g[1..nn-1], builds 1 - g(x) in g, and prints the first
// nn coefficients of 1 / (1 - g(x)) separated (and followed) by a space.
// Returns 1 with a message on stderr when the input is malformed or too large.
int main () {
#ifdef chitongz
    freopen ("x.in", "r", stdin);
#endif
    if (scanf ("%d", &nn) != 1 || nn < 1) {
        fprintf (stderr, "expected a positive coefficient count\n");
        return 1;
    }

    // Inv transforms at the smallest power of two >= 2 * nn; make sure the
    // work arrays (all declared with the same size as f) can hold that many entries.
    const long long capacity = sizeof (f) / sizeof (f[0]);
    long long len = 1;
    while (len < 2LL * nn) len <<= 1;
    if (len > capacity) {
        fprintf (stderr, "coefficient count %d exceeds the supported size\n", nn);
        return 1;
    }

    for (int i = 1; i < nn; ++i) {
        if (scanf ("%d", &g[i]) != 1) {
            fprintf (stderr, "failed to read coefficient %d\n", i);
            return 1;
        }
        // Store -g[i] in canonical form [0, mod); mns (0, x) maps 0 to 0
        // rather than to mod.
        g[i] = mns (0, g[i] % mod);
    }
    // Only the constant term receives the "1" of 1 - g(x).
    g[0] = add (g[0], 1);

    Inv (g, f, nn);

    for (int i = 0; i < nn; ++i) printf ("%d ", f[i]);
    puts ("");
    return 0;
}
Conclusion
注意区分什么时候是对每一项系数做减法,什么时候只是对常数项做减法。