/* 【数学】多项式(模板) — [Math] Polynomial operations (template) */

constexpr int MAXN = 1 << 21;   /* capacity (in coefficients) of every poly_t buffer */
constexpr int MOD = 998244353;  /* NTT-friendly prime: MOD - 1 = 2^23 * 7 * 17 */
constexpr int G = 3;            /* primitive root modulo MOD, used as the NTT generator */

/* (x + y) mod MOD. Expects x, y in [0, MOD): only one conditional subtraction is done. */
inline int qadd(const int &x, const int &y) {
    const int sum = x + y;
    if(sum < MOD)
        return sum;
    return sum - MOD;
}

/* (x - y) mod MOD. Expects x, y in [0, MOD): only one conditional addition is done. */
inline int qsub(const int &x, const int &y) {
    const int diff = x - y;
    if(diff >= 0)
        return diff;
    return diff + MOD;
}

inline int qmul(const int &x, const int &y) {
    ll r = 1LL * x * y;
    if(r >= MOD)
        r %= MOD;
    return r;
}

/* Binary exponentiation: x^n mod MOD, intended for n >= 0.
 * qpow(a, MOD - 2) is the modular inverse of a (Fermat), used throughout. */
int qpow(ll x, ll n) {
    ll result = 1;
    for(; n != 0; n >>= 1) {
        if(n & 1)
            result = qmul(result, x);
        x = qmul(x, x);
    }
    return result;
}

/* Fixed-size coefficient buffer: a poly_t holds coefficients [0, MAXN). */
using poly_t = int[MAXN];
/* Polynomial parameter type: pointer to the first coefficient. The pointer itself is
 * const (cannot be reseated); the coefficients it points to are writable. */
using poly = int *const;

/* Print h[0..n-1] on one line, space-separated, terminated by a newline.
 * Every one of the n slots is printed as stored, so if the polynomial has fewer
 * than n meaningful coefficients the tail shows whatever the buffer holds. */
void polyshow(poly &h, int n) {
    int i = 0;
    while(i != n) {
        const char sep = (i + 1 == n) ? '\n' : ' ';
        printf("%d%c", h[i], sep);
        ++i;
    }
}

/* In-place number-theoretic transform of h[0..n) over Z/MOD.
 * n must be a power of two no larger than 2^23 (the largest power of two dividing
 * MOD - 1). op == -1 computes the inverse transform, including the division by n;
 * any other op computes the forward transform. */
void FNTT(poly &h, int n, int op) {
    /* Bit-reversal permutation: rev tracks the bit-reversed index of i. */
    for(int i = 1, rev = n >> 1; i < n - 1; ++i) {
        if(i < rev)
            swap(h[i], h[rev]);
        int bit = n >> 1;
        while(bit <= rev) {
            rev -= bit;
            bit >>= 1;
        }
        rev += bit;
    }
    /* Iterative butterflies on blocks of size len = 2, 4, ..., n. */
    for(int len = 2; len <= n; len <<= 1) {
        const int half = len >> 1;
        const int wn = qpow(G, (MOD - 1) / len);   /* primitive len-th root of unity */
        for(int start = 0; start < n; start += len) {
            int w = 1;
            for(int j = start; j < start + half; ++j) {
                const int u = h[j];
                const int v = qmul(h[j + half], w);
                h[j] = qadd(u, v);
                h[j + half] = qsub(u, v);
                w = qmul(w, wn);
            }
        }
    }
    if(op != -1)
        return;
    /* Inverse: the forward transform evaluated at reversed roots, scaled by 1/n. */
    reverse(h + 1, h + n);
    const int invN = qpow(n, MOD - 2);
    for(int i = 0; i != n; ++i)
        h[i] = qmul(h[i], invN);
}

/* Round n up to the smallest power of two that is >= n (n == 0 becomes 1), and
 * zero the padding h[old n .. new n) so the extra coefficients are 0.
 * Expects n >= 0. */
void pretreat(poly &h, int &n) {
    int size = 1;
    for(; size < n; size <<= 1) {}
    fill(h + n, h + size, 0);
    n = size;
}

/* f[i] = h1[i] + h2[i] (mod MOD) for 0 <= i < max(n1, n2); returns max(n1, n2).
 * Side effect: both inputs are zero-extended in place to the common length
 * (h1[n1..) and h2[n2..) up to max(n1, n2) are overwritten with 0).
 * f may be the same buffer as h1 or h2 (the operation is element-wise).
 * High-order zero coefficients are not trimmed from the returned length. */
int polyadd(poly &h1, int n1, poly &h2, int n2, poly &f) {
    const int len = n1 < n2 ? n2 : n1;
    fill(h1 + n1, h1 + len, 0);
    fill(h2 + n2, h2 + len, 0);
    for(int i = 0; i < len; ++i)
        f[i] = qadd(h1[i], h2[i]);
    return len;
}

/* Subtract h2 from h1: f[i] = h1[i] - h2[i] (mod MOD) for 0 <= i < max(n1, n2);
 * returns max(n1, n2).
 * Side effect: both inputs are zero-extended in place to the common length
 * (h1[n1..) and h2[n2..) up to max(n1, n2) are overwritten with 0).
 * f may be the same buffer as h1 or h2 (the operation is element-wise).
 * High-order zero coefficients are not trimmed from the returned length. */
int polysub(poly &h1, int n1, poly &h2, int n2, poly &f) {
    const int len = n1 < n2 ? n2 : n1;
    fill(h1 + n1, h1 + len, 0);
    fill(h2 + n2, h2 + len, 0);
    for(int i = 0; i < len; ++i)
        f[i] = qsub(h1[i], h2[i]);
    return len;
}

/* Multiply h1 (n1 coefficients) by h2 (n2 coefficients) and store the product in f.
 * Returns the product length n1 + n2 - 1, or 0 if either operand is empty.
 *
 * With tn = smallest power of two >= n1 + n2 - 1 (must satisfy tn <= MAXN):
 *  - h1[0..tn) is overwritten with its forward transform, and so is h2[0..tn)
 *    unless h2 is the very same buffer as h1;
 *  - f[0..tn) is written; f may be the same buffer as h1 or h2.
 * Squaring (h1 == h2) is supported. Other partial overlaps of h1[0..tn) and
 * h2[0..tn) are not. */
int polymul(poly &h1, int n1, poly &h2, int n2, poly &f) {
    if(n1 <= 0 || n2 <= 0)
        return 0;                           /* product with an empty polynomial */
    static poly_t tmp;                      /* scratch copy for h1 == h2 with n1 != n2 */
    const int n = n1 + n2 - 1;
    int tn = 1;
    while(tn < n)
        tn <<= 1;
    /* Transforming h1 in place also destroys h2 when they share a buffer, so the
     * second operand either reuses h1's transform (true squaring) or is copied. */
    int *g = h2;
    if(h2 == h1 && n2 != n1) {
        copy(h2, h2 + n2, tmp);
        g = tmp;
    }
    fill(h1 + n1, h1 + tn, 0), FNTT(h1, tn, 1);
    if(g != h1)                             /* g == h1: already transformed above */
        fill(g + n2, g + tn, 0), FNTT(g, tn, 1);
    for(int i = 0; i != tn; ++i)
        f[i] = qmul(h1[i], g[i]);
    FNTT(f, tn, -1);
    return n;
}

/* The following functions compute their results modulo x^n, i.e. only the first n coefficients are produced. */

/* Compute f = 1 / h mod x^n by Newton iteration f <- f * (2 - h * f), which doubles
 * the number of correct coefficients each round.
 *
 * Requirements: h[0] != 0 (otherwise no inverse exists), f is a different buffer
 * from h, and 2 * N <= MAXN, where N is n rounded up to a power of two.
 * Side effects: h[n..N) is zero-filled. On return f[0..N) holds the inverse and
 * f[N..2N) is zero. Does nothing for n <= 0. */
void polyinv(poly &h, int n, poly &f) {
    if(n <= 0)
        return;                             /* nothing to compute modulo x^0 */
    assert(h[0] != 0);                      /* constant term must be invertible */
    pretreat(h, n);
    static poly_t tmp;
    fill(f, f + n + n, 0), f[0] = qpow(h[0], MOD - 2);
    for(int t = 2; t <= n; t <<= 1) {
        /* tmp = h mod x^t, zero-padded to the transform length 2t. */
        copy(h, h + t, tmp), fill(tmp + t, tmp + t + t, 0);
        FNTT(f, t + t, 1), FNTT(tmp, t + t, 1);
        /* f has t/2 coefficients and tmp has t, so f*f*tmp has degree < 2t and the
         * length-2t cyclic product does not wrap around. */
        for(int i = 0; i != t + t; ++i)
            f[i] = qmul(f[i], qsub(2, qmul(f[i], tmp[i])));
        FNTT(f, t + t, -1);
        fill(f + t, f + t + t, 0);          /* keep f mod x^t */
    }
}

/* Derivative of h (n coefficients): f[i - 1] = i * h[i] for 1 <= i < n, and
 * f[n - 1] = 0, so f also has n coefficients. Does nothing for n <= 0.
 * f may be the same buffer as h: the ascending loop reads h[i] before slot i
 * is overwritten. */
void polyder(poly &h, int n, poly &f) {
    if(n <= 0)
        return;                             /* no coefficient to write; avoids f[-1] */
    for(int i = 1; i < n; ++i)
        f[i - 1] = qmul(h[i], i);
    f[n - 1] = 0;
}

/* Integral of h with constant term C: f[0] = C and f[i] = h[i - 1] / i (mod MOD)
 * for 1 <= i < n, so f has n coefficients and h[0..n-2] are read. C is stored
 * as given (not reduced). Requires n <= MAXN. Does nothing for n <= 0.
 * The descending loop allows f to be the same buffer as h. */
void polyint(poly &h, int n, poly &f, int C = 0) {
    if(n <= 0)
        return;                             /* avoids the i = -1 start of the loop */
    /* Modular inverses of 1..MAXN-1, built once on first use with
     * inv[i] = -(MOD / i) * inv[MOD % i]; inv[0] doubles as the "built" flag. */
    static int inv[MAXN];
    if(inv[0] == 0) {
        inv[1] = 1;
        for(int i = 2; i < MAXN; ++i)
            inv[i] = qmul(inv[MOD % i], (MOD - MOD / i));
        inv[0] = -1;
    }
    for(int i = n - 1; i > 0; --i)
        f[i] = qmul(h[i - 1], inv[i]);     /* or inv[i] = qpow(i, MOD - 2) */
    f[0] = C;
}

/* Compute f = ln(h) mod x^n as the integral of h' / h (constant term 0).
 *
 * Requirements: h[0] == 1, f is a different buffer from h, and 2 * N <= MAXN,
 * where N is n rounded up to a power of two.
 * Side effects: h[n..N) is zero-filled. On return f[0..N) holds the logarithm and
 * f[N..2N) is zero, as with polyinv/polyexp. Does nothing for n <= 0. */
void polylog(poly &h, int n, poly &f) {
    if(n <= 0)
        return;
    assert(h[0] == 1);
    pretreat(h, n);
    static poly_t tmp;
    polyder(h, n, tmp), polyinv(h, n, f);   /* tmp = h', f = 1 / h */
    fill(tmp + n, tmp + n + n, 0);
    /* h' has n - 1 terms and 1/h has n, so the product fits in length 2n. */
    FNTT(tmp, n + n, 1), FNTT(f, n + n, 1);
    for(int i = 0; i != n + n; ++i)
        tmp[i] = qmul(tmp[i], f[i]);
    FNTT(tmp, n + n, -1);
    polyint(tmp, n, f);                     /* f[0..n) = integral of h'/h */
    fill(f + n, f + n + n, 0);              /* clear transformed data left in the upper half */
}

/* Compute f = exp(h) mod x^n by Newton iteration f <- f * (1 - ln f + h), which
 * doubles the number of correct coefficients each round.
 *
 * Requirements: h[0] == 0, f is a different buffer from h, and 2 * N <= MAXN,
 * where N is n rounded up to a power of two.
 * Side effects: h[n..N) is zero-filled. On return f[0..N) holds the result and
 * f[N..2N) is zero. */
void polyexp(poly &h, int n, poly &f) {
    assert(h[0] == 0);
    pretreat(h, n);
    static poly_t tmp;
    fill(f, f + n + n, 0);
    f[0] = 1;                               /* exp(h) = 1 + O(x) */
    for(int len = 2; len <= n; len <<= 1) {
        const int dbl = len << 1;
        polylog(f, len, tmp);               /* tmp = ln f mod x^len */
        /* tmp = (1 + h - ln f) mod x^len, zero-padded to length 2 * len. */
        tmp[0] = qsub(qadd(h[0], 1), tmp[0]);
        for(int i = 1; i < len; ++i)
            tmp[i] = qsub(h[i], tmp[i]);
        fill(tmp + len, tmp + dbl, 0);
        /* f = f * tmp, keeping only the first len coefficients. */
        FNTT(f, dbl, 1);
        FNTT(tmp, dbl, 1);
        for(int i = 0; i < dbl; ++i)
            f[i] = qmul(f[i], tmp[i]);
        FNTT(f, dbl, -1);
        fill(f + len, f + dbl, 0);
    }
}
/* 上一篇 (previous post): SpringMVC-Controller
 * 下一篇 (next post): 实验2-2-9 计算火车运行时间 (15 分) */