多项式模板

多项式乘法（NTT，模 998244353）

#include <bits/stdc++.h>

using namespace std;

namespace Polynomial {

    typedef long long LL;
    using std::vector;

    // K is log2 of the longest supported transform and N = 2^K.
    // MOD = 998244353 = 119 * 2^23 + 1 with primitive root 3, so any K <= 23 is valid.
    const int K = 21, N = 1 << K;
    const LL MOD = 998244353;

    // Returns a^b mod MOD by binary exponentiation; expects b >= 0.
    LL Pow(LL a, LL b) {
        LL c = 1;
        for (; b; a = a * a % MOD, b >>= 1)
            if (b & 1) c = c * a % MOD;
        return c;
    }

    // Fills the tables for transform length n = 2^k (k >= 1):
    //   w[i]  = g^i for 0 <= i < n, where g = 3^((MOD - 1) / n);
    //   sp[i] = i with its k low bits reversed (bit-reversal permutation).
    void Init(int k, int *sp, LL *w) {
        int n = 1 << k;
        w[0] = 1;
        w[1] = Pow(3, (MOD - 1) / n);
        for (int i = 2; i < n; ++i)
            w[i] = w[i - 1] * w[1] % MOD;
        sp[0] = 0;
        for (int i = 1; i < n; ++i)
            sp[i] = (sp[i >> 1] >> 1) | ((i & 1) ? (n >> 1) : 0);
    }

    // In-place forward NTT of length 2^k: bit-reversal permutation followed by
    // iterative radix-2 butterflies. Requires a.size() == 2^k and coefficients
    // in [0, MOD); the output also lies in [0, MOD).
    // Throws std::length_error if k is outside [0, K].
    void DFT(vector<LL> &a, int k) {
        // Lazily built, cached per-level tables. Level j (1 <= j <= K) owns the
        // 2^j slots starting at offset 2^j - 1, so levels 1..K fit in 2N slots.
        // The pointer arrays are indexed by level and therefore need K + 1
        // entries (with only K entries, level K itself was out of bounds).
        static int _k = 0;
        static int _sp[N * 2], *sp[K + 1];
        static LL _w[N * 2], *w[K + 1];
        if (k < 0 || k > K)
            throw std::length_error("NTT length exceeds 2^K");
        for (; _k < k; ) {
            ++_k;
            sp[_k] = _sp + (1 << _k) - 1;
            w[_k] = _w + (1 << _k) - 1;
            Init(_k, sp[_k], w[_k]);
        }
        int n = 1 << k;
        assert((int)a.size() == n);
        if (k == 0) return;  // a length-1 transform is the identity
        for (int i = 0; i < n; ++i)
            if (i < sp[k][i]) std::swap(a[i], a[sp[k][i]]);
        for (int j = 1; j <= k; ++j) {
            int l = 1 << j, m = l >> 1;
            for (int s = 0; s < n; s += l) {
                LL *p = a.data() + s;
                for (int i = 0; i < m; ++i) {
                    LL t = p[i + m] * w[j][i] % MOD;
                    if ((p[i + m] = p[i] - t) < 0) p[i + m] += MOD;
                    if ((p[i] += t) >= MOD) p[i] -= MOD;
                }
            }
        }
    }

    // Inverse NTT of length 2^k: forward transform, reverse a[1..n-1], then
    // scale every entry by n^{-1}.
    void IDFT(vector<LL> &a, int k) {
        DFT(a, k);
        int n = 1 << k;
        LL invN = Pow(n, MOD - 2);
        std::reverse(a.begin() + 1, a.end());
        for (int i = 0; i < n; ++i)
            a[i] = a[i] * invN % MOD;
    }

    // Replaces a with the product a * b (coefficients mod MOD), sized
    // a.size() + b.size() - 1. An empty operand is treated as the zero
    // polynomial (all-zero result of that length; empty if both are empty).
    void _NTT(vector<LL> &a, vector<LL> b) {
        int _n = a.size(), _m = b.size();
        int len = _n + _m - 1;
        if (_n == 0 || _m == 0) {
            a.assign(std::max(len, 0), 0LL);
            return;
        }
        int k = 1;
        while ((1 << k) < len) ++k;  // transform covers the whole product: no wrap-around
        int n = 1 << k;
        a.resize(n);
        b.resize(n);
        DFT(a, k);
        DFT(b, k);
        for (int i = 0; i < n; ++i)
            a[i] = a[i] * b[i] % MOD;
        IDFT(a, k);
        a.resize(len);
    }

    // Returns the product a * b (coefficients mod MOD); see _NTT.
    vector<LL> NTT(vector<LL> a, vector<LL> b) {
        _NTT(a, std::move(b));
        return a;
    }

    // Polynomial stored as a coefficient vector: c[i] is the coefficient of x^i.
    struct Poly {
        vector<LL> c;

        Poly(vector<LL> c = {0}) : c(std::move(c)) {}

        Poly(LL x) : c(1, x) {}

        // Replaces the polynomial by the constant x; returns this.
        Poly* operator = (const LL x) {
            c.assign(1, x);
            return this;
        }

        inline int Deg() const {
            return (int)c.size() - 1;
        }

        // Leaves the coefficient vector empty (Deg() == -1).
        inline void Clear() {
            c.clear();
        }

        // Reads n + 1 coefficients from stdin, reducing each into [0, MOD).
        void Read(int n) {
            c.assign(std::max(n + 1, 0), 0LL);
            for (int i = 0; i <= n; ++i) {
                std::scanf("%lld", &c[i]);
                c[i] %= MOD;
                if (c[i] < 0) c[i] += MOD;
            }
        }

        // Prints the coefficients in ascending order, space-terminated, then a newline.
        void Print() const {
            for (LL x : c)
                std::printf("%lld ", x);
            std::puts("");
        }

    };

    void Read(int n, Poly &a) {
        a.Read(n);
    }

    void Print(const Poly &a) {
        a.Print();
    }

    Poly operator * (const Poly &a, const Poly &b) {
        return Poly(NTT(a.c, b.c));
    }

    Poly* operator *= (Poly &a, const Poly &b) {
        _NTT(a.c, b.c);
        return &a;
    }

}

using namespace Polynomial;

Poly a, b;

// Reads degrees n and m, then n + 1 and m + 1 coefficients, and prints the
// coefficients of the product (mod 998244353) in ascending order.
int main() {
    int n, m;
    if (scanf("%d%d", &n, &m) != 2 || n < 0 || m < 0)
        return 1;  // malformed header: n and m would otherwise be used uninitialized
    Read(n, a);
    Read(m, b);
    a *= b;
    Print(a);
    return 0;
}

多项式快速幂（模 998244353，允许常数项为 0）

#include <bits/stdc++.h>
 
using namespace std;
 
namespace Polynomial {

    typedef long long LL;
    using std::vector;

    // K is log2 of the longest supported transform and N = 2^K.
    // MOD = 998244353 = 119 * 2^23 + 1 with primitive root 3, so any K <= 23 is valid.
    const int K = 21, N = 1 << K;
    const LL MOD = 998244353;

    // Returns a^b mod MOD by binary exponentiation; expects b >= 0.
    LL Pow(LL a, LL b) {
        LL c = 1;
        for (; b; a = a * a % MOD, b >>= 1)
            if (b & 1) c = c * a % MOD;
        return c;
    }

    // Maps any integer, including negative ones, into [0, MOD).
    inline LL Norm(LL x) {
        x %= MOD;
        return x < 0 ? x + MOD : x;
    }

    // inv[i] = i^{-1} mod MOD for 1 <= i < N once InitInv() has run.
    LL inv[N];

    // Fills inv[] with the linear recurrence inv[i] = -(MOD / i) * inv[MOD % i].
    // Idempotent; InvOf() triggers it on first use, so calling it is optional.
    void InitInv() {
        inv[1] = 1;
        for (int i = 2; i < N; ++i)
            inv[i] = (MOD - MOD / i) * inv[MOD % i] % MOD;
    }

    // Modular inverse of i >= 1: table lookup for i < N, exponentiation beyond.
    inline LL InvOf(int i) {
        if (!inv[1]) InitInv();
        return i < N ? inv[i] : Pow(i, MOD - 2);
    }

    // Fills the tables for transform length n = 2^k (k >= 1):
    //   w[i]  = g^i for 0 <= i < n, where g = 3^((MOD - 1) / n);
    //   sp[i] = i with its k low bits reversed (bit-reversal permutation).
    void Init(int k, int *sp, LL *w) {
        int n = 1 << k;
        w[0] = 1;
        w[1] = Pow(3, (MOD - 1) / n);
        for (int i = 2; i < n; ++i)
            w[i] = w[i - 1] * w[1] % MOD;
        sp[0] = 0;
        for (int i = 1; i < n; ++i)
            sp[i] = (sp[i >> 1] >> 1) | ((i & 1) ? (n >> 1) : 0);
    }

    // In-place forward NTT of length 2^k: bit-reversal permutation followed by
    // iterative radix-2 butterflies. Requires a.size() == 2^k and coefficients
    // in [0, MOD); the output also lies in [0, MOD).
    // Throws std::length_error if k is outside [0, K].
    void DFT(vector<LL> &a, int k) {
        // Lazily built, cached per-level tables. Level j (1 <= j <= K) owns the
        // 2^j slots starting at offset 2^j - 1, so levels 1..K fit in 2N slots.
        // The pointer arrays are indexed by level and therefore need K + 1 entries.
        static int _k = 0;
        static int _sp[N * 2], *sp[K + 1];
        static LL _w[N * 2], *w[K + 1];
        if (k < 0 || k > K)
            throw std::length_error("NTT length exceeds 2^K");
        for (; _k < k; ) {
            ++_k;
            sp[_k] = _sp + (1 << _k) - 1;
            w[_k] = _w + (1 << _k) - 1;
            Init(_k, sp[_k], w[_k]);
        }
        int n = 1 << k;
        assert((int)a.size() == n);
        if (k == 0) return;  // a length-1 transform is the identity
        for (int i = 0; i < n; ++i)
            if (i < sp[k][i]) std::swap(a[i], a[sp[k][i]]);
        for (int j = 1; j <= k; ++j) {
            int l = 1 << j, m = l >> 1;
            for (int s = 0; s < n; s += l) {
                LL *p = a.data() + s;
                for (int i = 0; i < m; ++i) {
                    LL t = p[i + m] * w[j][i] % MOD;
                    if ((p[i + m] = p[i] - t) < 0) p[i + m] += MOD;
                    if ((p[i] += t) >= MOD) p[i] -= MOD;
                }
            }
        }
    }

    // Inverse NTT of length 2^k: forward transform, reverse a[1..n-1], then
    // scale every entry by n^{-1}.
    void IDFT(vector<LL> &a, int k) {
        DFT(a, k);
        int n = 1 << k;
        LL invN = Pow(n, MOD - 2);
        std::reverse(a.begin() + 1, a.end());
        for (int i = 0; i < n; ++i)
            a[i] = a[i] * invN % MOD;
    }

    // Shared product kernel: replaces a with the first len coefficients of
    // a * b (zero-padded when len exceeds the true length, cleared when
    // len <= 0). The transform covers the whole product, so nothing wraps.
    void MultiplyTo(vector<LL> &a, vector<LL> b, int len) {
        if (len <= 0) {
            a.clear();
            return;
        }
        if (a.empty() || b.empty()) {  // an empty operand is the zero polynomial
            a.assign(len, 0LL);
            return;
        }
        int full = (int)a.size() + (int)b.size() - 1, k = 1;
        while ((1 << k) < full) ++k;
        int n = 1 << k;
        a.resize(n);
        b.resize(n);
        DFT(a, k);
        DFT(b, k);
        for (int i = 0; i < n; ++i)
            a[i] = a[i] * b[i] % MOD;
        IDFT(a, k);
        a.resize(len);
    }

    // a <- a * b, sized a.size() + b.size() - 1.
    vector<LL>* _Multiply(vector<LL> &a, vector<LL> b) {
        int len = (int)a.size() + (int)b.size() - 1;
        MultiplyTo(a, std::move(b), len);
        return &a;
    }
    // a <- a * b mod x^m (exactly m coefficients; empty when m <= 0).
    vector<LL>* _Multiply(vector<LL> &a, vector<LL> b, int m) {
        if (m <= 0) {
            a.clear();
            return &a;
        }
        // Terms at or above x^m cannot influence the kept part, so drop them first.
        if ((int)a.size() > m) a.resize(m);
        if ((int)b.size() > m) b.resize(m);
        MultiplyTo(a, std::move(b), m);
        return &a;
    }
    vector<LL> Multiply(vector<LL> a, vector<LL> b) {
        _Multiply(a, std::move(b));
        return a;
    }
    vector<LL> Multiply(vector<LL> a, vector<LL> b, int m) {
        _Multiply(a, std::move(b), m);
        return a;
    }

    // Scales every coefficient by b (any integer, reduced into [0, MOD) first).
    vector<LL>* _Multiply(vector<LL> &a, LL b) {
        b = Norm(b);
        for (LL &x : a)
            x = x * b % MOD;
        return &a;
    }
    vector<LL> Multiply(vector<LL> a, LL b) {
        _Multiply(a, b);
        return a;
    }

    vector<LL>* _Resize(vector<LL> &a, int n) {
        a.resize(n);
        return &a;
    }
    vector<LL> Resize(vector<LL> a, int n) {
        a.resize(n);
        return a;
    }

    // Inverse of a modulo x^{a.size()} by Newton iteration b <- b * (2 - a * b),
    // doubling the precision each round. Throws std::domain_error when a is
    // empty or a[0] == 0 (no inverse exists).
    vector<LL> Reciprocal(const vector<LL> &a) {
        int _n = a.size();
        if (_n == 0 || !a[0])
            throw std::domain_error("irreversible");
        vector<LL> b(1, Pow(a[0], MOD - 2)), c;
        for (int n = 1, k = 0; n < _n; ) {
            n <<= 1;  // precision reached at the end of this round
            ++k;
            // Length-2n transforms hold b * b * (a mod x^n) without wrap-around.
            b.resize(n * 2);
            c.assign(n * 2, 0LL);
            for (int i = 0; i < n && i < _n; ++i)
                c[i] = a[i];
            DFT(b, k + 1);
            DFT(c, k + 1);
            for (int i = 0; i < n * 2; ++i)
                b[i] = b[i] * (MOD + 2 - b[i] * c[i] % MOD) % MOD;
            IDFT(b, k + 1);
            std::fill(b.begin() + n, b.end(), 0);
        }
        b.resize(_n);
        return b;
    }

    // a <- a * x^n (prepends n zeros); n <= 0 leaves a unchanged.
    vector<LL>* _LeftShift(vector<LL> &a, int n) {
        if (n > 0)
            a.insert(a.begin(), static_cast<std::size_t>(n), 0LL);
        return &a;
    }
    vector<LL> LeftShift(vector<LL> a, int n) {
        _LeftShift(a, n);
        return a;
    }

    // a <- a / x^n, dropping the n lowest coefficients; becomes {0} if nothing remains.
    vector<LL>* _RightShift(vector<LL> &a, int n) {
        if (n >= (int)a.size())
            a.assign(1, 0LL);
        else if (n > 0)
            a.erase(a.begin(), a.begin() + n);
        return &a;
    }
    vector<LL> RightShift(vector<LL> a, int n) {
        _RightShift(a, n);
        return a;
    }

    // a <- a + b (coefficients assumed in [0, MOD)); grows a to b's length if shorter.
    vector<LL>* _Add(vector<LL> &a, const vector<LL> &b) {
        if (a.size() < b.size()) a.resize(b.size());
        for (std::size_t i = 0; i < b.size(); ++i)
            if ((a[i] += b[i]) >= MOD) a[i] -= MOD;
        return &a;
    }
    vector<LL> Add(vector<LL> a, const vector<LL> &b) {
        _Add(a, b);
        return a;
    }

    // Adds the constant b (any integer) to the x^0 coefficient.
    vector<LL>* _Add(vector<LL> &a, LL b) {
        if (a.empty()) a.assign(1, 0LL);
        a[0] = Norm(a[0] + Norm(b));
        return &a;
    }
    vector<LL> Add(vector<LL> a, LL b) {
        _Add(a, b);
        return a;
    }

    // a <- a - b (coefficients assumed in [0, MOD)); grows a to b's length if shorter.
    vector<LL>* _Subtract(vector<LL> &a, const vector<LL> &b) {
        if (a.size() < b.size()) a.resize(b.size());
        for (std::size_t i = 0; i < b.size(); ++i)
            if ((a[i] -= b[i]) < 0) a[i] += MOD;
        return &a;
    }
    vector<LL> Subtract(vector<LL> a, const vector<LL> &b) {
        _Subtract(a, b);
        return a;
    }

    // Subtracts the constant b (any integer) from the x^0 coefficient.
    vector<LL>* _Subtract(vector<LL> &a, LL b) {
        if (a.empty()) a.assign(1, 0LL);
        a[0] = Norm(a[0] - Norm(b));
        return &a;
    }
    vector<LL> Subtract(vector<LL> a, LL b) {
        _Subtract(a, b);
        return a;
    }

    // Divides every coefficient by the scalar b. Throws std::domain_error if
    // b is a multiple of MOD (it used to silently zero the polynomial).
    vector<LL>* _Divide(vector<LL> &a, LL b) {
        b = Norm(b);
        if (!b)
            throw std::domain_error("division by zero");
        return _Multiply(a, Pow(b, MOD - 2));
    }
    vector<LL> Divide(vector<LL> a, LL b) {
        _Divide(a, b);
        return a;
    }

    // Polynomial quotient floor(a / b) via reversal and a truncated inverse.
    // Returns {0} when deg a < deg b. Requires b's leading coefficient to be non-zero.
    vector<LL> Divide(vector<LL> a, vector<LL> b) {
        int n = (int)a.size() - 1, m = (int)b.size() - 1;
        if (n < m) return vector<LL>(1, 0LL);
        int len = n - m + 1;
        std::reverse(a.begin(), a.end());
        std::reverse(b.begin(), b.end());
        a.resize(len);
        b.resize(len);
        _Multiply(a, Reciprocal(b), len);
        std::reverse(a.begin(), a.end());
        return a;
    }

    // Remainder a mod b, of length max(deg b, 1) ({0}-sized for constant b).
    vector<LL> Modulo(const vector<LL> &a, const vector<LL> &b) {
        int m = (int)b.size() - 1;
        return Resize(Subtract(a, Multiply(Divide(a, b), b)), std::max(m, 1));
    }

    // a <- a'. A constant becomes {0}; an empty vector stays empty.
    vector<LL>* _Derivative(vector<LL> &a) {
        int n = a.size();
        if (n == 1) {
            a[0] = 0;
            return &a;
        }
        for (int i = 1; i < n; ++i)
            a[i - 1] = a[i] * i % MOD;
        if (n > 1) a.pop_back();
        return &a;
    }
    vector<LL> Derivative(vector<LL> a) {
        _Derivative(a);
        return a;
    }

    // a <- integral of a with zero constant term (one coefficient longer).
    vector<LL>* _Integral(vector<LL> &a) {
        int n = a.size();
        a.resize(n + 1);
        for (int i = n; i >= 1; --i)
            a[i] = a[i - 1] * InvOf(i) % MOD;
        a[0] = 0;
        return &a;
    }
    vector<LL> Integral(vector<LL> a) {
        _Integral(a);
        return a;
    }

    // ln(a) mod x^{a.size()} = integral(a' / a). Requires a[0] == 1, otherwise
    // throws std::domain_error.
    vector<LL> Logarithm(vector<LL> a) {
        if (a.empty() || a[0] != 1)
            throw std::domain_error("logarithm function undefined");
        int n = a.size();
        // The reciprocal must be taken before a is differentiated in place.
        vector<LL> r = Reciprocal(a);
        _Derivative(a);
        _Multiply(a, std::move(r), n - 1);
        _Integral(a);
        return a;
    }
    // In-place ln. Works on a copy, so a is untouched if the domain check throws.
    // (The old version had unspecified argument evaluation order: a could be
    // differentiated before its reciprocal was taken.)
    vector<LL>* _Logarithm(vector<LL> &a) {
        a = Logarithm(a);
        return &a;
    }

    // exp(a) mod x^{a.size()} by Newton iteration b <- b * (1 - ln b + a).
    // Requires a[0] == 0, otherwise throws std::domain_error.
    vector<LL> Exponential(vector<LL> a) {
        if (!a.empty() && a[0])
            throw std::domain_error("exponential function undefined");
        int target = a.size();
        vector<LL> b(1, 1);
        for (int n = 1; n < target; ) {
            n <<= 1;
            b.resize(n);
            vector<LL> t = Subtract(Resize(a, n), Logarithm(b));
            _Add(t, 1LL);
            _Multiply(b, std::move(t), n);
        }
        b.resize(target);
        return b;
    }

    // a^b mod x^{a.size()} for b >= 0, with any constant term. Writes
    // a = c * x^k * A(x) with A(0) = 1, then a^b = c^b * x^{k*b} * exp(b * ln A).
    // Uses the convention 0^0 = 1.
    vector<LL> Pow(vector<LL> a, LL b) {
        int n = a.size();
        int k = 0;
        while (k < n && !a[k]) ++k;
        if (k == n) {  // a is identically zero (or empty)
            vector<LL> r(n, 0LL);
            if (n > 0 && b == 0) r[0] = 1;
            return r;
        }
        // The lowest surviving term is x^{k*b}; test k*b >= n without overflowing.
        if (k > 0 && b >= (n - 1) / k + 1)
            return vector<LL>(n, 0LL);
        LL c = a[k];
        _RightShift(a, k);
        _Divide(a, c);
        a.resize(n);
        _Logarithm(a);
        _Multiply(a, b);
        a = Exponential(std::move(a));
        _Multiply(a, Pow(c, b));
        _LeftShift(a, (int)(k * b));
        a.resize(n);
        return a;
    }

    // Value wrapper around a coefficient vector: c[i] is the coefficient of x^i.
    struct Poly {
        vector<LL> c;

        Poly(vector<LL> c = vector<LL>(1)) : c(std::move(c)) {}

        Poly(LL x) : c(1, x) {}

        // Replaces the polynomial by the constant x; returns this.
        Poly* operator = (const LL x) {
            c.assign(1, x);
            return this;
        }

        Poly Reciprocal() const {
            return Poly(Polynomial::Reciprocal(c));
        }

        Poly Logarithm() const {
            return Poly(Polynomial::Logarithm(c));
        }

        Poly* _Logarithm() {
            Polynomial::_Logarithm(c);
            return this;
        }

        Poly Exponential() const {
            return Poly(Polynomial::Exponential(c));
        }

        Poly Pow(LL b) const {
            return Poly(Polynomial::Pow(c, b));
        }

        Poly* Resize(int n) {
            c.resize(n);
            return this;
        }

        void _Resize(int n) {
            c.resize(n);
        }

        inline int Deg() const {
            return (int)c.size() - 1;
        }

        // Resets to the zero polynomial {0}.
        void Clear() {
            c = vector<LL>(1);
        }

        // Reads n + 1 coefficients from stdin, reducing each into [0, MOD).
        void Read(int n) {
            c.assign(std::max(n + 1, 0), 0LL);
            for (int i = 0; i <= n; ++i) {
                std::scanf("%lld", &c[i]);
                c[i] = Norm(c[i]);
            }
        }

        // Prints the coefficients in ascending order, space-terminated, then a newline.
        void Print() const {
            for (LL x : c)
                std::printf("%lld ", x);
            std::puts("");
        }

    };

    Poly Read(int n) {
        Poly a;
        a.Read(n);
        return a;
    }

    void _Read(int n, Poly &a) {
        a.Read(n);
    }

    void Print(Poly a) {
        a.Print();
    }

    void _Print(const Poly &a) {
        a.Print();
    }

    Poly Resize(Poly a, int n) {
        return *a.Resize(n);
    }

    Poly* _Resize(Poly &a, int n) {
        a._Resize(n);
        return &a;
    }

    Poly Reciprocal(const Poly &a) {
        return a.Reciprocal();
    }

    Poly Logarithm(const Poly &a) {
        return a.Logarithm();
    }

    Poly* _Logarithm(Poly &a) {
        return a._Logarithm();
    }

    Poly Exponential(const Poly &a) {
        return a.Exponential();
    }

    Poly Pow(const Poly &a, LL b) {
        return a.Pow(b);
    }

    Poly operator << (const Poly &a, int n) {
        return Poly(LeftShift(a.c, n));
    }

    Poly* operator <<= (Poly &a, int n) {
        _LeftShift(a.c, n);
        return &a;
    }

    Poly operator >> (const Poly &a, int n) {
        return Poly(RightShift(a.c, n));
    }

    Poly* operator >>= (Poly &a, int n) {
        _RightShift(a.c, n);
        return &a;
    }

    Poly operator + (const Poly &a, const Poly &b) {
        return Poly(Add(a.c, b.c));
    }

    Poly operator + (const Poly &a, LL b) {
        return Poly(Add(a.c, b));
    }

    Poly* operator += (Poly &a, const Poly &b) {
        _Add(a.c, b.c);
        return &a;
    }

    Poly* operator += (Poly &a, LL b) {
        _Add(a.c, b);
        return &a;
    }

    Poly operator - (const Poly &a, const Poly &b) {
        return Poly(Subtract(a.c, b.c));
    }

    Poly operator - (const Poly &a, LL b) {
        return Poly(Subtract(a.c, b));
    }

    Poly* operator -= (Poly &a, const Poly &b) {
        _Subtract(a.c, b.c);
        return &a;
    }

    Poly* operator -= (Poly &a, LL b) {
        _Subtract(a.c, b);
        return &a;
    }

    Poly operator * (const Poly &a, const Poly &b) {
        return Poly(Multiply(a.c, b.c));
    }

    Poly operator * (const Poly &a, LL b) {
        return Poly(Multiply(a.c, b));
    }

    Poly* operator *= (Poly &a, LL b) {
        _Multiply(a.c, b);
        return &a;
    }

    Poly* operator *= (Poly &a, const Poly &b) {
        _Multiply(a.c, b.c);
        return &a;
    }

    Poly operator / (const Poly &a, const Poly &b) {
        return Poly(Divide(a.c, b.c));
    }

    Poly operator / (const Poly &a, LL b) {
        return Poly(Divide(a.c, b));
    }

    Poly* operator /= (Poly &a, const Poly &b) {
        a.c = Divide(a.c, b.c);
        return &a;
    }

    Poly* operator /= (Poly &a, LL b) {
        _Divide(a.c, b);
        return &a;
    }

    Poly operator % (const Poly &a, const Poly &b) {
        return Poly(Modulo(a.c, b.c));
    }

    Poly* operator %= (Poly &a, const Poly &b) {
        a.c = Modulo(a.c, b.c);
        return &a;
    }

}
 
using namespace Polynomial;
 
Poly a, b;

// Reads n and an exponent k, then n coefficients of a(x), and prints
// a(x)^k mod x^n (coefficients mod 998244353) in ascending order.
int main() {
    InitInv();
    int n;
    LL k;
    // Reject malformed input: n <= 0 leaves nothing to raise, and a negative k
    // would make the scalar Pow loop forever (b >>= 1 never reaches 0).
    if (scanf("%d%lld", &n, &k) != 2 || n <= 0 || k < 0)
        return 1;
    _Read(n - 1, a);
    Print(Pow(a, k));
    return 0;
}
上一篇:文本检测网络EAST学习(二)


下一篇:多项式模板