[清华集训2017]生成树计数——生成函数

题面

  Bzoj5119

解析

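  考虑任一长度为$n-2$、每个元素取值于$[1,n]$的序列(即 Prüfer 序列):它唯一对应一棵$n$个节点的有标号无根树,反之亦然,即树与 Prüfer 序列之间是双射。并且节点$i$在序列中恰好出现$d_i$次,当且仅当它在树中的度数为$d_i+1$。把第$i$个连通块(共$a_i$个点)看作一个节点:连通块$i$与$j$之间连一条边有$a_ia_j$种选法,因此度数为$d_i+1$的连通块会为方案数贡献因子$a_i^{d_i+1}$。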

  那么可以将问题转化为枚举$prufer$序列:$$\begin{align*}Ans&=\sum_{\sum_{i}d_i=n-2}\frac{(n-2)!}{\prod_id_i!}(\prod_ia_i^{d_i+1}(d_i+1)^m)*(\sum_i(d_i+1)^m)\\&=(n-2)!*(\prod_ia_i)*(\sum_{\sum_{i}d_i=n-2}(\prod_i\frac{a_i^{d_i}(d_i+1)^m}{d_i!})*(\sum_i(d_i+1)^m))\\&=(n-2)!*(\prod_ia_i)*(\sum_{\sum_{i}d_i=n-2}\sum_i((d_i+1)^m*\prod_j\frac{a_j^{d_j}(d_j+1)^m}{d_j!}))\\&=(n-2)!*(\prod_ia_i)*(\sum_{\sum_{i}d_i=n-2}\sum_i(\frac{a_i^{d_i}(d_i+1)^{2m}}{d_i!}*\prod_{j\neq i}\frac{a_j^{d_j}(d_j+1)^m}{d_j!}))\\\end{align*}$$

  设$$A(x)=\sum_{i=0}^{\infty}\frac{(i+1)^{2m}}{i!}x^i\\ B(x)=\sum_{i=0}^{\infty}\frac{(i+1)^{m}}{i!}x^i\\ F(x)=\sum_iA(a_ix)\prod_{j\neq i}B(a_jx)$$

  对$F(x)$化简:$$\begin{align*}F(x)&=\sum_iA(a_ix)\prod_{j\neq i}B(a_jx)\\&=(\sum_i\frac{A(a_ix)}{B(a_ix)})*\prod_iB(a_ix)\\&=(\sum_i\frac{A(a_ix)}{B(a_ix)})*\exp(\ln(\prod_iB(a_ix)))\\&=(\sum_i\frac{A(a_ix)}{B(a_ix)})*\exp(\sum_i\ln(B(a_ix)))\end{align*}$$

  再设$$C(x)=\frac{A(x)}{B(x)}\\ D(x)=\ln(B(x))$$

  有:$$[x^j](\sum_i\ln(B(a_ix))) = ([x^j]D(x))*\sum_ia_i^j\\ ([x^j]\sum_{i}\frac{A(a_ix)}{B(a_ix)})=([x^j]C(x))*\sum_ia_i^j$$

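  求出$C(x)$与$D(x)$后,把它们的第$i$项乘以$\sum_ja_j^i$即可。这需要求出数列$a$的$0\sim n$次方和:由$\ln\prod_i(1-a_ix)=-\sum_{k\ge1}\frac{\sum_ia_i^k}{k}x^k$,先用分治 NTT 求出$\prod_i(1-a_ix)$,再取$\ln$即可(详见我的生成函数小结,这里不再展开)。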

  最终答案:$$Ans=(n-2)!*(\prod_ia_i)*[x^{n-2}]F(x)$$

  时间复杂度为$O(N\log^2 N)$,瓶颈在于用分治 NTT 求$\prod_i(1-a_ix)$。

 代码:

[清华集训2017]生成树计数——生成函数
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<vector>
// Children of segment-tree node x. These macros only make sense inside
// solve(), whose node-index parameter is named x.
#define ls (x << 1)
#define rs ((x << 1) | 1)
using namespace std;
typedef long long ll;
// Capacity used to size the global arrays (polynomial buffers use maxn << 1).
const int maxn = 60005;
// NTT prime modulus and the root g used by NTT() for forward transforms.
const int mod = 998244353;
const int g = 3;

// Modular addition: both operands are expected to lie in [0, mod).
int add(int x, int y)
{
    int sum = x + y;
    if(sum >= mod)
        sum -= mod;
    return sum;
}

// Modular subtraction x - y: both operands are expected to lie in [0, mod).
int rdc(int x, int y)
{
    int diff = x - y;
    if(diff < 0)
        diff += mod;
    return diff;
}

// Binary exponentiation: returns x^y modulo `mod` (1 when y == 0).
// x should already be reduced modulo `mod` so that base * base cannot overflow.
ll qpow(ll x, int y)
{
    ll result = 1;
    for(ll base = x; y != 0; y >>= 1)
    {
        if(y & 1)
            result = result * base % mod;
        base = base * base % mod;
    }
    return result;
}

// Input: number of components n, exponent m, component sizes a[1..n].
int n, m;
int a[maxn];
// NTT state written by NTT_init() and read by NTT(): transform length,
// its log2, and the bit-reversal permutation.
int lim, bit;
int rev[maxn<<1];
// ginv = g^{-1}; fac[i] = i!, fnv[i] = 1/i!, inv[i] = 1/i (filled by init()).
ll ginv;
ll fac[maxn], fnv[maxn], inv[maxn];
// Polynomial buffers. c is scratch for get_inv/get_exp/solve, d is scratch
// for solve and reused by main, ln is get_exp's log buffer, iv is get_ln's
// inverse buffer; A, B, f, h are used by main.
ll A[maxn<<1], B[maxn<<1];
ll c[maxn<<1], d[maxn<<1];
ll ln[maxn<<1], iv[maxn<<1];
ll f[maxn<<1], h[maxn<<1];
// G[x]: coefficient list built by solve() for segment-tree node x.
vector<int> G[maxn<<1];

// Precomputes ginv = g^{-1} and, for 0 <= i <= n, fac[i] = i!,
// inv[i] = 1/i (inv[0] is set to 1) and fnv[i] = 1/i!, all modulo `mod`.
void init()
{
    ginv = qpow(g, mod - 2);
    fac[0] = 1;
    inv[0] = inv[1] = 1;
    fnv[0] = fnv[1] = 1;
    for(int i = 1; i <= n; ++i)
    {
        fac[i] = fac[i-1] * i % mod;
        if(i < 2)
            continue;
        // Linear-time inverse: 1/i = -(mod / i) * 1/(mod % i), with mod % i < i.
        inv[i] = (mod - mod / i) * inv[mod % i] % mod;
        fnv[i] = fnv[i-1] * inv[i] % mod;
    }
}

// Sets lim to the smallest power of two strictly greater than x, bit to
// log2(lim), and rebuilds the bit-reversal table rev[1..lim-1].
void NTT_init(int x)
{
    for(lim = 1, bit = 0; lim <= x; lim <<= 1)
        ++bit;
    for(int i = 1; i < lim; ++i)
    {
        const int top = (i & 1) << (bit - 1);
        rev[i] = (rev[i >> 1] >> 1) | top;
    }
}

// In-place iterative NTT over x[0..lim-1] (lim comes from NTT_init).
// y == 1 uses root g (forward); y == -1 uses ginv and then multiplies every
// entry by 1/lim (inverse).
void NTT(ll *x, int y)
{
    // Reorder into bit-reversed positions.
    for(int i = 1; i < lim; ++i)
    {
        if(i < rev[i])
            swap(x[i], x[rev[i]]);
    }
    for(int half = 1; half < lim; half <<= 1)
    {
        const int span = half << 1;
        const ll step = qpow(y == 1 ? g : ginv, (mod - 1) / span);
        for(int start = 0; start < lim; start += span)
        {
            ll w = 1;
            for(int k = start; k < start + half; ++k)
            {
                const ll even = x[k];
                const ll odd = x[k + half] * w % mod;
                x[k] = add(even, odd);
                x[k + half] = rdc(even, odd);
                w = w * step % mod;
            }
        }
    }
    if(y != -1)
        return;
    const ll scale = qpow(lim, mod - 2);
    for(int i = 0; i < lim; ++i)
        x[i] = x[i] * scale % mod;
}

// Power-series inverse by Newton iteration: on return x[0..len-1] holds
// 1/y modulo t^len and x[len..lim-1] is zero. Needs y[0] != 0 (mod `mod`).
// Expects x to be zero past the prefix it is building; uses the global
// scratch buffer c and leaves it zeroed.
void get_inv(ll *x, ll *y, int len)
{
    if(len == 1)
    {
        x[0] = qpow(y[0], mod - 2);
        return ;
    }
    // First the inverse modulo t^ceil(len/2) ...
    get_inv(x, y, (len + 1) >> 1);
    // ... then one Newton step: x <- x * (2 - y * x) modulo t^len.
    copy(y, y + len, c);
    NTT_init(len << 1);
    NTT(x, 1);
    NTT(c, 1);
    for(int i = 0; i < lim; ++i)
    {
        const ll yx = c[i] * x[i] % mod;
        x[i] = x[i] * rdc(2, yx) % mod;
        c[i] = 0;
    }
    NTT(x, -1);
    fill(x + len, x + lim, 0LL);
}

// Power-series logarithm: x <- integral(y' / y) modulo t^len, which is ln(y)
// when y[0] == 1 (the integration constant is taken as 0). Reads y[0..len],
// expects x to be zero from index len on, and leaves x[len..lim-1] zero.
// Uses the global buffer iv for 1/y and leaves it zeroed.
void get_ln(ll *x, ll *y, int len)
{
    // x <- y', truncated to len coefficients.
    for(int i = 1; i <= len; ++i)
        x[i - 1] = y[i] * i % mod;
    get_inv(iv, y, len);
    NTT_init(len << 1);
    NTT(x, 1);
    NTT(iv, 1);
    for(int i = 0; i < lim; ++i)
    {
        x[i] = x[i] * iv[i] % mod;
        iv[i] = 0;
    }
    NTT(x, -1);
    // Integrate: coefficient i becomes coefficient i-1 divided by i. Walking
    // downwards reads each source value before it is overwritten.
    for(int i = len - 1; i > 0; --i)
        x[i] = x[i - 1] * inv[i] % mod;
    x[0] = 0;
    fill(x + len, x + lim, 0LL);
}

// Power-series exponential by Newton iteration: x <- exp(y) modulo t^len,
// assuming y[0] == 0. Expects x to be zero past the prefix being refined and
// leaves x[len..lim-1] zero. Uses the global buffers ln and c as scratch.
void get_exp(ll *x, ll *y, int len)
{
    if(len == 1)
    {
        x[0] = 1;
        return ;
    }
    get_exp(x, y, (len + 1) >> 1);
    // Newton step: x <- x * (1 - ln(x) + y) modulo t^len.
    get_ln(ln, x, len);
    for(int i = 0; i < len; ++i)
    {
        c[i] = rdc(y[i], ln[i]);
        ln[i] = 0;
    }
    c[0] = add(c[0], 1);
    NTT_init(len << 1);
    NTT(x, 1);
    NTT(c, 1);
    for(int i = 0; i < lim; ++i)
    {
        x[i] = x[i] * c[i] % mod;
        c[i] = 0;
    }
    NTT(x, -1);
    fill(x + len, x + lim, 0LL);
}

void solve(int x, int l, int r, int *y)
{
    if(l == r)
    {
        G[x].push_back(1);
        G[x].push_back(rdc(0, y[l]));
        return; 
    }
    int mid = (l + r) >> 1;
    solve(ls, l, mid, y);
    solve(rs, mid + 1, r, y);
    for(int i = 0; i <= mid - l + 1; ++i)
        c[i] = G[ls][i];
    for(int i = 0; i <= r - mid; ++i)
        d[i] = G[rs][i];
    NTT_init(r - l + 1);
    NTT(c, 1);
    NTT(d, 1);
    for(int i = 0; i < lim; ++i)
    {
        c[i] = c[i] * d[i] % mod;
        d[i] = 0;
    }
    NTT(c, -1);
    for(int i = 0; i <= r - l + 1; ++i)
    {
        G[x].push_back(c[i]);
        c[i] = 0;
    }
    for(int i = r - l + 2; i < lim; ++i)
        c[i] = 0;
}

int main()
{
    scanf("%d%d", &n, &m);
    init();
    ll ans = fac[n-2];
    for(int i = 1; i <= n; ++i)
    {
        scanf("%d", &a[i]);
        ans = ans * a[i] % mod;
    }
    solve(1, 1, n, a);
    for(int i = 0; i <= n; ++i)
        d[i] = G[1][i];
    get_ln(f, d, n + 1);
    for(int i = n; i >= 1; --i)
    {
        //f[i] = f[i] * i % mod;
        f[i] = rdc(0, f[i] * i % mod);
        d[i] = 0;
    }
    f[0] = n;

    ll tmp;
    for(int i = 0; i <= n; ++i)
    {
        tmp = qpow(i + 1, m);
        B[i] = tmp * fnv[i] % mod;
        A[i] = B[i] * tmp % mod;
    }
    get_ln(d, B, n + 1);
    get_inv(h, B, n + 1);
    NTT_init(n << 1);
    NTT(A, 1);
    NTT(h, 1);
    for(int i = 0; i < lim; ++i)
        A[i] = A[i] * h[i] % mod;
    NTT(A, -1);
    for(int i = 0; i <= n; ++i)
    {
        A[i] = A[i] * f[i] % mod;
        d[i] = d[i] * f[i] % mod;
    }
    for(int i = n + 1; i < lim; ++i)
        A[i] = 0;
    memset(B, 0, sizeof(B));
    get_exp(B, d, n + 1);
    NTT_init(n << 1);
    NTT(A, 1);
    NTT(B, 1);
    for(int i = 0; i < lim; ++i)
        A[i] = A[i] * B[i] % mod;
    NTT(A, -1);
    printf("%lld", ans * A[n-2] % mod);
    return 0;
}
View Code
上一篇:Linux命令:ln


下一篇:linux运维基础[linux常用命令]——————软、硬链接与ln命令