CF923E Perpetual Subtraction 题解

洛谷传送门

f(τ,i)f(\tau,i)f(τ,i)表示τ\tauτ次操作进行之前,也就是第(τ1)(\tau-1)(τ−1)次操作进行之后得到的数是iii的概率。那么f(1,i)=pif(1,i)=p_if(1,i)=pi​,我们需要对于每个iii求出f(m+1,i)f(m+1,i)f(m+1,i)。

不要问我为什么用τ\tauτ这个奇怪的字母

考虑一次操作带来的影响:显然有

f(τ+1,i)=j=inf(τ,j)j+1f(\tau+1,i)=\sum\limits_{j=i}^n\frac{f(\tau,j)}{j+1}f(τ+1,i)=j=i∑n​j+1f(τ,j)​

看起来不是很明显,我们利用生成函数来找关系。设Fτ(x)=i=0nf(τ,i)xiF_\tau(x)=\sum\limits_{i=0}^nf(\tau,i)x^iFτ​(x)=i=0∑n​f(τ,i)xi,那么

Fτ+1(x)=i=0nf(τ+1,i)xiF_{\tau+1}(x)=\sum\limits_{i=0}^nf(\tau+1,i)x^iFτ+1​(x)=i=0∑n​f(τ+1,i)xi

=i=0nj=inf(τ,j)j+1xi=\sum\limits_{i=0}^n\sum\limits_{j=i}^n\frac{f(\tau,j)}{j+1}x^i=i=0∑n​j=i∑n​j+1f(τ,j)​xi

=j=0nf(τ,j)j+1i=0jxi=\sum\limits_{j=0}^n\frac{f(\tau,j)}{j+1}\sum\limits_{i=0}^jx^i=j=0∑n​j+1f(τ,j)​i=0∑j​xi

=j=0nf(τ,j)j+1xj+11x1=\sum\limits_{j=0}^n\frac{f(\tau,j)}{j+1}\cdot\frac{x^{j+1}-1}{x-1}=j=0∑n​j+1f(τ,j)​⋅x−1xj+1−1​

=1x1j=0nf(τ,j)xj+11j+1=\frac{1}{x-1}\sum\limits_{j=0}^nf(\tau,j)\frac{x^{j+1}-1}{j+1}=x−11​j=0∑n​f(τ,j)j+1xj+1−1​

=1x1j=0nf(τ,j)1xtjdt=\frac{1}{x-1}\sum\limits_{j=0}^nf(\tau,j)\int_1^xt^j\mathrm{d}t=x−11​j=0∑n​f(τ,j)∫1x​tjdt

=1x11xFτ(t)dt=\frac{1}{x-1}\int_1^xF_\tau(t)\mathrm{d}t=x−11​∫1x​Fτ​(t)dt

我们喜闻乐见的是0xf(t)dt=f(x)dx\int_0^xf(t)\mathrm{d}t=\int f(x)\mathrm{d}x∫0x​f(t)dt=∫f(x)dx并且常数项为000。但这里积分下限是111,不好处理。

Gτ(x)=Fτ(x+1)G_\tau(x)=F_\tau(x+1)Gτ​(x)=Fτ​(x+1),并且Gτ(x)=i=0ng(τ,i)xiG_\tau(x)=\sum\limits_{i=0}^ng(\tau,i)x^iGτ​(x)=i=0∑n​g(τ,i)xi,那么

Gτ+1(x)=1x1x+1Fτ(t)dtG_{\tau+1}(x)=\frac{1}{x}\int_1^{x+1}F_\tau(t)\mathrm{d}tGτ+1​(x)=x1​∫1x+1​Fτ​(t)dt

=1x0xFτ(t+1)d(t+1)=\frac{1}{x}\int_0^xF_\tau(t+1)\mathrm{d}(t+1)=x1​∫0x​Fτ​(t+1)d(t+1)

=i=0ng(τ,i)i+1xi=\sum\limits_{i=0}^n\frac{g(\tau,i)}{i+1}x^i=i=0∑n​i+1g(τ,i)​xi

于是我们发现

g(τ+1,i)=g(τ,i)i+1g(\tau+1,i)=\frac{g(\tau,i)}{i+1}g(τ+1,i)=i+1g(τ,i)​

这是等比数列。那么

g(m+1,i)=g(1,i)(i+1)mg(m+1,i)=\frac{g(1,i)}{(i+1)^m}g(m+1,i)=(i+1)mg(1,i)​

下面的问题就是如何求出g(1,i)g(1,i)g(1,i),以及如何由g(m+1,i)g(m+1,i)g(m+1,i)求出f(m+1,i)f(m+1,i)f(m+1,i)。

Gτ(x)=Fτ(x+1)\because G_\tau(x)=F_\tau(x+1)∵Gτ​(x)=Fτ​(x+1)

i=0ng(τ,i)xi=i=0nf(τ,i)j=0iCijxj\therefore \sum\limits_{i=0}^ng(\tau,i)x^i=\sum\limits_{i=0}^nf(\tau,i)\sum\limits_{j=0}^iC_i^jx^j∴i=0∑n​g(τ,i)xi=i=0∑n​f(τ,i)j=0∑i​Cij​xj

=j=0nxji=jnf(τ,i)Cij=\sum\limits_{j=0}^nx^j\sum\limits_{i=j}^nf(\tau,i)C_i^j=j=0∑n​xji=j∑n​f(τ,i)Cij​

g(τ,i)=j=inCjif(τ,j)\therefore g(\tau,i)=\sum\limits_{j=i}^nC_j^if(\tau,j)∴g(τ,i)=j=i∑n​Cji​f(τ,j)

τ=1\tau=1τ=1,这时f(τ,i)=pif(\tau,i)=p_if(τ,i)=pi​,那么

g(1,i)=j=inj!pji!(ji)!g(1,i)=\sum\limits_{j=i}^n\frac{j!p_j}{i!(j-i)!}g(1,i)=j=i∑n​i!(j−i)!j!pj​​

=1i!j=0ni1j!(j+i)!pj+i=\frac{1}{i!}\sum\limits_{j=0}^{n-i}\frac{1}{j!}(j+i)!p_{j+i}=i!1​j=0∑n−i​j!1​(j+i)!pj+i​

ai=1i!,bi=(ni)!pnia_i=\frac{1}{i!},b_i=(n-i)!p_{n-i}ai​=i!1​,bi​=(n−i)!pn−i​,那么

g(1,i)=1i!j=0niajbnijg(1,i)=\frac{1}{i!}\sum\limits_{j=0}^{n-i}a_jb_{n-i-j}g(1,i)=i!1​j=0∑n−i​aj​bn−i−j​

NTT求卷积即可。

然后再考虑怎么求出f(m+1,i)f(m+1,i)f(m+1,i)。为了方便起见,将f(m+1,i)f(m+1,i)f(m+1,i)记为fif_ifi​,将g(m+1,i)g(m+1,i)g(m+1,i)记为gig_igi​。

我们发现由fff求出ggg的过程就是个卷积,现在要反过来求,我们需要多项式求逆吗?并不需要,只要一个反演即可。

gi=j=inCjifj\because g_i=\sum\limits_{j=i}^nC_j^if_j∵gi​=j=i∑n​Cji​fj​

fi=j=in(1)jiCjigj\therefore f_i=\sum\limits_{j=i}^n(-1)^{j-i}C_j^ig_j∴fi​=j=i∑n​(−1)j−iCji​gj​

=1i!j=in(1)ji(ji)!j!gj=\frac{1}{i!}\sum\limits_{j=i}^n\frac{(-1)^{j-i}}{(j-i)!}j!g_j=i!1​j=i∑n​(j−i)!(−1)j−i​j!gj​

=1i!j=0ni(1)jj!(j+i)!gj+i=\frac{1}{i!}\sum\limits_{j=0}^{n-i}\frac{(-1)^j}{j!}(j+i)!g_{j+i}=i!1​j=0∑n−i​j!(−1)j​(j+i)!gj+i​

ci=(1)ii!,di=(ni)!gnic_i=\frac{(-1)^i}{i!},d_i=(n-i)!g_{n-i}ci​=i!(−1)i​,di​=(n−i)!gn−i​,那么

fi=1i!j=0nicjdnijf_i=\frac{1}{i!}\sum\limits_{j=0}^{n-i}c_jd_{n-i-j}fi​=i!1​j=0∑n−i​cj​dn−i−j​

NTT求卷积即可。

时间复杂度O(nlogn)O(n\log n)O(nlogn)。

#include <cctype>
#include <cstdio>
#include <climits>
#include <algorithm>

template <typename T> inline void read(T& x) {
    int f = 0, c = getchar(); x = 0;
    while (!isdigit(c)) f |= c == '-', c = getchar();
    while (isdigit(c)) x = x * 10 + c - 48, c = getchar();
    if (f) x = -x;
}
template <typename T, typename... Args>
inline void read(T& x, Args&... args) {
    read(x); read(args...); 
}
template <typename T> void write(T x) {
    if (x < 0) x = -x, putchar('-');
    if (x > 9) write(x / 10);
    putchar(x % 10 + 48);
}
template <typename T> inline void writeln(T x) { write(x); puts(""); }
template <typename T> inline bool chkmin(T& x, const T& y) { return y < x ? (x = y, true) : false; }
template <typename T> inline bool chkmax(T& x, const T& y) { return x < y ? (x = y, true) : false; }

typedef long long LL;

const LL mod = 998244353, G = 3, Gi = 332748118;
const int maxn = 1e5 + 7;

inline LL qpow(LL x, LL k) {
    LL s = 1;
    for (; k; x = x * x % mod, k >>= 1)
        if (k & 1) s = s * x % mod;
    return s;
}

inline void ntt(LL *A, int *r, int lim, int tp) {
    for (int i = 0; i < lim; ++i)
        if (i < r[i]) std::swap(A[i], A[r[i]]);
    for (int mid = 1; mid < lim; mid <<= 1) {
        LL wn = qpow(tp == 1 ? G : Gi, (mod - 1) / (mid << 1));
        for (int j = 0; j < lim; j += mid << 1) {
            LL w = 1;
            for (int k = 0; k < mid; ++k, w = w * wn % mod) {
                LL x = A[j + k], y = w * A[j + k + mid] % mod;
                A[j + k] = (x + y) % mod;
                A[j + k + mid] = (x - y + mod) % mod;
            }
        }
    }
    if (tp == -1) {
        LL inv = qpow(lim, mod - 2);
        for (int i = 0; i < lim; ++i)
            A[i] = A[i] * inv % mod;
    }
}

int n, r[maxn << 2], lim, l;
LL m, p[maxn], fac[maxn], ifac[maxn];
LL a[maxn << 2], b[maxn << 2], g[maxn];

int main() {
    read(n, m);
    for (int i = 0; i <= n; ++i) read(p[i]);
    fac[0] = ifac[0] = 1;
    for (int i = 1; i <= n + 1; ++i)
        fac[i] = fac[i - 1] * i % mod;
    ifac[n + 1] = qpow(fac[n + 1], mod - 2);
    for (int i = n; i; --i)
        ifac[i] = ifac[i + 1] * (i + 1) % mod;
    for (int i = 0; i <= n; ++i) {
        a[i] = ifac[i];
        b[i] = fac[n - i] * p[n - i] % mod;
    }
    for (lim = 1; lim <= (n << 1); ++l) lim <<= 1;
    for (int i = 0; i < lim; ++i)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
    ntt(a, r, lim, 1); ntt(b, r, lim, 1);
    for (int i = 0; i < lim; ++i)
        a[i] = a[i] * b[i] % mod;
    ntt(a, r, lim, -1);
    for (int i = 0; i <= n; ++i)
        g[i] = ifac[i] * a[n - i] % mod * qpow(qpow(i + 1, m), mod - 2) % mod;
    for (int i = 0; i <= n; ++i) {
        a[i] = ((i & 1) ? mod - 1 : 1ll) * ifac[i] % mod;
        b[i] = fac[n - i] * g[n - i] % mod;
    }
    for (int i = n + 1; i < lim; ++i)
        a[i] = b[i] = 0;
    ntt(a, r, lim, 1); ntt(b, r, lim, 1);
    for (int i = 0; i < lim; ++i)
        a[i] = a[i] * b[i] % mod;
    ntt(a, r, lim, -1);
    for (int i = 0; i <= n; ++i)
        write(ifac[i] * a[n - i] % mod), putchar(' ');
    return 0;
}
上一篇:2-8-10-16进制详解


下一篇:CS入门学习笔记6-MIT 6.00.1x