二项式反演 学习笔记

概念

二项式反演其实就是利用容斥的思想处理一些通过求 至少或至多 来解决 恰好 的问题。

形式

\[\begin{align*} f(n)=\sum_{i=0}^n(-1)^i\binom n i g(i)&\iff g(n)=\sum_{i=0}^n(-1)^i\binom n i f(i) \\ f(n)=\sum_{i=0}^n\binom n i g(i)&\iff g(n)=\sum_{i=0}^n(-1)^{n-i}\binom n i f(i) \\ f(n)=\sum_{i=n}^m \binom i n g(i)&\iff g(n)=\sum_{i=n}^m(-1)^{i-n}\binom i n f(i) \end{align*} \]

其中,形式三比较常用,组合意义为 \(f(n)\) 表示“至少选 \(n\) 个”,\(g(n)\) 表示“恰好选 \(n\) 个”。

例题

Luogu P4859 已经没有什么好害怕的了

Link

Description

给定两个长为 \(n\) 的序列 \(a,b\),它们两两配对,求配对后 \(a>b\) 的组数比 \(b>a\) 的组数恰好多 \(k\) 组的方案数。

\(1\le n \le 2000,0\le k\le n\)

Solution

题目要求“恰好多 \(k\) 组”,共有 \(n\) 组,所以相当于 \(a>b\) 恰好 \(\dfrac {n+k}2\) 组。

设 \(dp_{i,j}\) 表示前 \(i\) 个数中,有 \(j\) 组 \(a>b\) 的方案数,转移方程为

\[dp_{i,j}=dp_{i-1,j}+dp_{i-1,j-1}\times(cnt_i-(j-1)) \]

其中,\(cnt_i\) 表示 \(b\) 中比 \(a_i\) 小的数的个数,这个可以将 \(a,b\) 排序后双指针扫

接下来,记 \(f_i=dp_{n,i}\times (n-i)!\),也就是至少 \(i\) 组的方案数

然后根据二项式反演就可以得到恰好 \(k\) 组的方案数 \(g_k\)

\[g_k=\sum_{i=k}^n(-1)^{n-i}\binom i k f_i \]

Code
int n, k, a[N], b[N], cnt[N];
ll fac[N], dp[N][N], f[N], g[N];

ll qpow(ll a, int b)
{
    ll res = 1;
    while(b)
    {
        if(b & 1) res = res * a % mod;
        a = a * a % mod, b >>= 1;
    }
    return res;
}
ll add(ll x) {return x < mod ? x : x - mod;}
ll inv(ll x) {return qpow(x, mod - 2);}
ll C(int n, int m) {return n < m ? 0 : fac[n] * inv(fac[m]) % mod * inv(fac[n - m]) % mod;}

int main()
{
    read(n), read(k);

    if((n + k) & 1)
    {
        puts("0");
        return 0;
    }
    k = (n + k) >> 1;

    for(int i = 1; i <= n; i++) read(a[i]);
    for(int i = 1; i <= n; i++) read(b[i]);
    sort(a + 1, a + 1 + n);
    sort(b + 1, b + 1 + n);

    fac[0] = 1;
    for(int i = 1, j = 1; i <= n; i++)
    {
        while(j <= n && a[i] > b[j]) j++;
        cnt[i] = j - 1;
        fac[i] = fac[i - 1] * i % mod;
    }

    dp[0][0] = 1;
    for(int i = 1; i <= n; i++)
        for(int j = 0; j <= i; j++)
            dp[i][j] = add(dp[i - 1][j] + (!j ? 0 : dp[i - 1][j - 1] * (cnt[i] - j + 1) % mod));
    for(int i = 0; i <= n; i++) f[i] = dp[n][i] * fac[n - i] % mod;
    for(int i = 1; i <= n; i++)
        for(int j = k; j <= n; j++)
            g[i] = add(g[i] + add((((j - k) & 1) ? -1 : 1) * f[j] * C(j, k) % mod + mod));

    write(g[k]), pc('\n');
    return 0;
} 
// A.S.

Luogu P4491 [HAOI2018]染色

Link

Description

有一个长为 \(n\) 的序列,每个位置都可以是 \([1,m]\) 中的某一个数,若这 \(n\) 个数中恰好出现了 \(s\) 次的数有 \(k\) 个,那么会得到 \(w_k\) 的贡献。

求对于所有可能的情况,能获得的权值的和对 \(1004535809\) 取模的结果是多少。

\(1\le n\le 10^7,1\le m \le 10^5,0\le s\le 150,0\le w_i\le 1004535809\)

Solution

显然数的个数不会超过 \(cnt=\min(m,n/s)\)

依然是恰好出现 \(s\) 次,考虑计算有 \(i\) 个数至少出现 \(s\) 次的方案数 \(f_i\)

钦定 \(i\) 个数出现了 \(s\) 次,剩下的 \(n-is\) 个位置在 \(m-i\) 个数中随便选

\[f_i=\binom m i \dfrac{n!}{(s!)^i(n-is)!}(m-i)^{n-is} \]

然后进行二项式反演,设 \(g_k\) 表示有 \(k\) 个数恰好出现 \(s\) 次

\[\begin{align*} g_k&=\sum_{i=k}^m(-1)^{i-k}\binom i k f_i \\ g_k\times k!&=\sum(-1)^{i-k}\dfrac{i!}{(i-k)!}f_i \end{align*} \]

到这里就能看出来卷积的形式了

\[F(x)=\sum_{i=0}^mf_i\times i! \\ G(x)=\sum_{i=0}^m\dfrac{(-1)^i}{i!} \]

那么 \(g_i=\dfrac{(F*G)(i)}{i!}\)

NTT 计算卷积即可

Code
#include <bits/stdc++.h>
#define ll long long
#define db double
#define gc getchar
#define pc putchar

using namespace std;

namespace IO
{
    template <typename T>
    void read(T &x)
    {
        x = 0; bool f = 0; char c = gc();
        while(!isdigit(c)) f |= c == '-', c = gc();
        while(isdigit(c)) x = x * 10 + c - '0', c = gc();
        if(f) x = -x;
    }

    template <typename T>
    void write(T x)
    {
        if(x < 0) pc('-'), x = -x;
        if(x > 9) write(x / 10);
        pc('0' + x % 10);
    }
}
using namespace IO;

const int MAXN = 1e7 + 5;
const int N = 1e5 + 5;
const int mod = 1004535809;
const int G = 3;
const int Gi = 334845270;

ll add(ll x) {return x < mod ? x : x - mod;}
ll sub(ll x) {return x < 0 ? x + mod : x;}
ll qpow(ll a, int b)
{
    ll res = 1;
    while(b)
    {
        if(b & 1) res = res * a % mod;
        a = a * a % mod, b >>= 1;
    }
    return res;
}
ll inv(ll x) {return qpow(x, mod - 2);}

ll fac[MAXN], f[N << 2], g[N << 2];

ll C(int n, int m)
{
    return n < m ? 0 : fac[n] * inv(fac[m]) % mod * inv(fac[n - m]) % mod;
}

int rev[N << 2];

int calclim(int n)
{
    int lim = 1;
    while(lim < n) lim <<= 1;
    return lim;
}

void calcrev(int lim)
{
    for(int i = 0; i < lim; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * (lim >> 1));
}

void NTT(ll *a, int lim, int type)
{
    for(int i = 0; i < lim; i++)
        if(i < rev[i]) swap(a[i], a[rev[i]]);
    for(int mid = 1; mid < lim; mid <<= 1)
    {
        ll wn = qpow(type == 1 ? G : Gi, (mod - 1) / (mid << 1));
        for(int i = 0; i < lim; i += (mid << 1))
        {
            ll w = 1;
            for(int j = 0; j < mid; j++, w = w * wn % mod)
            {
                ll x = a[i + j], y = w * a[i + mid + j] % mod;
                a[i + j] = add(x + y);
                a[i + mid + j] = sub(x - y);
            }
        }
    }
    if(type == -1)
    {
        ll limi = qpow(lim, mod - 2);
        for(int i = 0; i < lim; i++) a[i] = a[i] * limi % mod;
    }
    return;
}

int main()
{
    int n, m, s;
    read(n), read(m), read(s);
    int cnt = min(m, n / s) + 1;
    fac[0] = 1;
    for(int i = 1; i < MAXN; i++)
        fac[i] = fac[i - 1] * i % mod;
    for(int i = 0; i < cnt; i++)
    {
        f[i] = fac[i] * C(m, i) % mod * fac[n] % mod * inv(qpow(fac[s], i)) % mod * inv(fac[n - s * i]) % mod * qpow(m - i, n - s * i) % mod;
        g[i] = (i & 1) ? mod - inv(fac[i]) : inv(fac[i]);
    }

    reverse(f, f + cnt);
    int lim = calclim(cnt << 1);
    calcrev(lim);
    NTT(f, lim, 1), NTT(g, lim, 1);
    for(int i = 0; i < lim; i++) f[i] = f[i] * g[i] % mod;
    NTT(f, lim, -1);
    reverse(f, f + cnt);
    
    ll ans = 0;
    for(int i = 0, w; i < cnt; i++)
        read(w), ans = add(ans + inv(fac[i]) * f[i] % mod * w % mod);
    write(ans), pc('\n');

    return 0;
}
// A.S.
上一篇:【做题记录】 ZJOI2009 假期的宿舍


下一篇:如何测试手机上的SOAP客户端