题面
题解
上了文化课之后终于知道“超几何分布”的准确定义了,这时候再回来看这题,突然灵光一闪,想到了一个新的解法。
超几何分布:\(n\) 个物品中有 \(m\) 个次品,不放回地抽取 \(k\) 个物品,记其中次品的个数为 \(x\),则 \(P(x = i) = \dfrac {\binom mi \binom {n - m} {k - i}} {\binom nk}\)。
那么其概率生成函数为 \(P(x) = \dfrac {\sum_{i} \binom mi \binom{n-m}{k-i} x^i} {\binom nk}\)。
加入辅助变量 \(y\),可得 \(P(x) = \dfrac {[y^k] (1 + xy)^m(1 + y)^{n - m}} {\binom nk}\)
由概率生成函数的定义可知,对任意非负整数 \(i\),\(E(x^{\underline i}) = P^{(i)}(1)\)(这里用 \(i\) 表示求导阶数,以免与抽取个数 \(k\) 混淆)。
所以 \(E(x^L) = \sum_{i=0}^L \begin{Bmatrix} L \\ i \end{Bmatrix} E(x^{\underline i}) = \sum_{i=0}^L \begin{Bmatrix} L \\ i \end{Bmatrix} P^{(i)}(1)\),那么问题就变成了求 \(P^{(i)}(1)\)。
\(P^{(i)}(x) = \dfrac {[y^k] m^{\underline i}y^i (1 + xy)^{m - i} (1 + y)^{n - m}} {\binom nk}\),所以 \(P^{(i)}(1) = \dfrac {[y^{k - i}] m^{\underline i} (1 + y)^{n - i}} {\binom nk} = \dfrac {m^{\underline i} \binom{n - i}{k - i}} {\binom nk} = \dfrac {k^{\underline i}m^{\underline i}} {n^{\underline i}}\),带到之前那个期望的式子里面去就可以了。
代码
#include <cstdio>
#include <algorithm>
#include <vector>
// Redirects stdin to "<x>.in" and stdout to "<x>.out" with freopen.
// The argument is stringized, so pass a bare token: file(cpp) opens cpp.in / cpp.out.
#define file(x) freopen(#x".in", "r", stdin), freopen(#x".out", "w", stdout)
/*
 * Reads the next decimal integer from stdin.
 *
 * Skips every character that is neither a digit nor '-', then parses an
 * optional single '-' followed by a run of digits.
 *
 * Returns the parsed value; returns 0 when the input ends before any digit
 * is found (the original looped forever in that situation).
 * Values that do not fit in an int are not detected.
 */
inline int read()
{
    // Keep the character in an int so that EOF stays distinguishable from data bytes.
    int ch = getchar();
    while (ch != EOF && ch != '-' && (ch < '0' || ch > '9')) ch = getchar();

    int sign = 1;
    if (ch == '-') sign = -1, ch = getchar();

    int data = 0;
    // (ch ^ 48) maps '0'..'9' to 0..9.
    while (ch >= '0' && ch <= '9') data = data * 10 + (ch ^ 48), ch = getchar();
    return data * sign;
}
// Array capacities and the prime modulus used by all arithmetic in this file.
const int maxn(6e5 + 10);   // capacity of the transform buffers (r, w, A, S)
const int M(2e7);           // base for the factorial-table capacity
const int maxm(M + 10);     // capacity of fac / inv
const int Mod(998244353);   // modulus

/*
 * Modular exponentiation by repeated squaring.
 *
 * x: base, y: exponent (expected to be non-negative).
 * Returns x^y modulo Mod; y == 0 yields 1 for every x, including 0.
 */
int fastpow(int x, int y)
{
    long long base = x;
    long long result = 1;
    while (y)
    {
        // Multiply in the current power whenever the low bit of the exponent is set.
        if (y & 1) result = result * base % Mod;
        base = base * base % Mod;
        y >>= 1;
    }
    return static_cast<int>(result);
}
// r: index permutation applied at the start of FFT; w: table of root powers.
// Both are filled by the caller (main) for the transform length before FFT is called.
int r[maxn], w[maxn];

/*
 * In-place iterative butterfly transform of p[0..N-1] modulo Mod.
 *
 * p: buffer of N values, each expected to lie in [0, Mod).
 * N: transform length; r[0..N-1] and w[0..N/2-1] must already be set up for it.
 *
 * Whether this is the forward or the inverse transform depends only on the
 * root powers currently stored in w; no scaling by 1/N is performed here.
 */
void FFT(int *p, int N)
{
    // Reorder the input according to r, swapping each pair once.
    for (int idx = 0; idx < N; ++idx)
    {
        const int target = r[idx];
        if (idx < target) std::swap(p[idx], p[target]);
    }

    // half: half-length of the current butterfly; stride: step through w for this stage.
    int stride = N >> 1;
    for (int half = 1; half < N; half <<= 1, stride >>= 1)
    {
        const int span = half << 1;
        for (int base = 0; base < N; base += span)
        {
            int *lo = p + base;
            int *hi = lo + half;
            for (int k = 0; k < half; ++k)
            {
                const int even = lo[k];
                const int odd = static_cast<int>(1ll * w[k * stride] * hi[k] % Mod);
                lo[k] = (even + odd) % Mod;
                // Add Mod before reducing so the difference never goes negative.
                hi[k] = (even - odd + Mod) % Mod;
            }
        }
    }
}
int n, m, N, P, T, L, A[maxn], S[maxn], fac[maxm], inv[maxm];
int main()
{
#ifndef ONLINE_JUDGE
file(cpp);
#endif
n = read(), m = read(), T = read(), L = read();
fac[0] = inv[0] = 1; int max = std::max(n, L);
for (int i = 1; i <= max; i++) fac[i] = 1ll * fac[i - 1] * i % Mod;
inv[max] = fastpow(fac[max], Mod - 2);
for (int i = max - 1; i; i--) inv[i] = 1ll * inv[i + 1] * (i + 1) % Mod;
for (int i = 0, o = 1; i <= L; i++, o = Mod - o) A[i] = 1ll * o * inv[i] % Mod;
for (int i = 0; i <= L; i++) S[i] = 1ll * fastpow(i, L) * inv[i] % Mod;
for (N = 1, P = -1; N <= (L << 1); N <<= 1, ++P);
for (int i = 0; i < N; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << P);
w[0] = 1, w[1] = fastpow(3, (Mod - 1) / N);
for (int i = 2; i < N; i++) w[i] = 1ll * w[i - 1] * w[1] % Mod;
FFT(A, N), FFT(S, N), w[1] = fastpow(332748118, (Mod - 1) / N);
for (int i = 0; i < N; i++) S[i] = 1ll * S[i] * A[i] % Mod;
for (int i = 2; i < N; i++) w[i] = 1ll * w[i - 1] * w[1] % Mod;
FFT(S, N); int invn = fastpow(N, Mod - 2);
for (int i = 0; i < N; i++) S[i] = 1ll * S[i] * invn % Mod;
while (T--)
{
int _n = read(), _m = read(), _k = read(), ans = 0;
int lim = std::min(L, std::min(_m, std::min(_n, _k)));
for (int i = 0; i <= lim; i++)
ans = (ans + 1ll * S[i] * inv[_m - i] % Mod * inv[_k - i] % Mod * fac[_n - i]) % Mod;
ans = 1ll * ans * fac[_m] % Mod * fac[_k] % Mod * inv[_n] % Mod;
printf("%d\n", ans);
}
return 0;
}