B
考虑集合不好算,先算一个长为 \(k\) 的序列(可以重复)的方案数,然后容斥出集合的方案。
第一部分:计算序列个数。
称一个可重集为一个「块」当且仅当其所有元素在 \([c\cdot 2^k,c\cdot 2^{k+1})\) 之间且所有元素的出现次数相同。于是 \([0,n]\) 可以分解成 \(\mathcal O(\log n)\) 个块。容易发现如下两个性质:
-
从一个块中选择若干个元素异或起来,得到的还是一个大小相同的块
-
从两个块中选择各若干个(非空个)元素异或起来,得到的还是一个块,且大小等于较大的那个
于是考虑枚举一个块 \((c,2^t)\),然后统计所有元素中最大的那个出现在这个块中的方案数。注意到此时所有小于这个块的数全都是等价的,只需要统计其个数 \(sum\),然后容易写出如下的柿子:
\[f_k\gets f_k+\sum_{x\geq 1}\frac{\binom{b}{b-h_x}}{2^t} \binom kx 2^{tx} sum^{k-x} \]其意义即为从当前块中选取 \(x\) 个元素,然后从后面选取 \(k-x\) 个元素异或起来,其中恰好有 \(h_x\) 个 \(1\) 的我们统计进答案,其中 \(h_x\) 为一个只和 \(x\) 奇偶性有关的量,其表示了二进制大于 \(t\) 的那些位置的 \(1\) 的个数。
然后我们要对每个 \(k\) 计算这个柿子,暴力复杂度是 \(\mathcal O(k^2\log n)\) 的,但是注意到这个柿子只和 \(x\) 奇偶性有关,于是考虑计算 \((A+Bx)^k\bmod (x^2-1)\) 的值,这可用长为 \(2\) 的 \(NTT\) 快速进行计算,于是复杂度优化成了 \(\mathcal O(k\log n)\)
第二部分:容斥。
考虑记 \(g_i\) 为 \(i\) 的答案,于是枚举一个 \(j<i\),然后计算去重后变成 \(j\) 的序列个数。相当于 \(j\) 个元素要求出现奇数次,\(n-j+1\) 个元素要求出现 \(n-j+1\) 次,可以写出如下的两个多项式:
\[\left\{ \begin{aligned} & \sinh(x)=\frac{e^x-e^{-x}}{2}=\sum_{i\geq 0}\frac{x^{2i+1}}{(2i+1)!}\\ & \cosh(x)=\frac{e^x+e^{-x}}{2}=\sum_{i\geq 0}\frac{x^{2i}}{(2i)!}\\ \end{aligned} \right. \]于是要求的容斥系数就是 \([x^i]\sinh^j(x)\cosh^{n-j+1}(x)\)。注意到这个函数求导后变成 \(j\sinh^{j-1}(x)\cosh^{n-j+1}(x)+(n-j+1)\sinh^j(x)\cosh^{n-j}(x)\),于是可以递归到子问题。
这部分的复杂度为 \(\mathcal O(k^2)\),总复杂度为 \(\mathcal O(k(k+\log n))\)
#include <ctime>
#include <cstdio>
#define nya(neko...) fprintf(stderr, neko)
__attribute__((destructor))
inline void ptime() {
nya("\nTime: %.3lf(s)\n", 1. * clock() / CLOCKS_PER_SEC);
}
#include <algorithm>
using ll = long long;
constexpr ll mod = 998244353;
constexpr ll inv2 = (mod + 1) / 2;
inline ll fsp(ll a, ll b, ll res = 1) {
for(a %= mod; b; a = a * a % mod, b >>= 1)
b & 1 ? res = res * a % mod : 0; return res;
}
#define ppcnt(x) __builtin_popcount(x)
int n, k, b, cnt;
struct Block { int k, c; } blk[100];
inline ll binom(ll n, int k) {
ll fz = 1, fm = 1;
for(int i = 1; i <= k; ++i)
fm = fm * i % mod, fz = fz * (n - i + 1) % mod;
return fsp(fm, mod - 2, fz);
}
constexpr int maxk = 5005;
ll C[100][100], coef[2][maxk];
ll f[maxk], g[maxk];
int main() {
for(int i = 0; i <= 50; ++i) {
C[i][0] = 1;
for(int j = 1; j <= i; ++j)
C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % mod;
}
scanf("%d%d%d", &n, &k, &b);
blk[++cnt] = { 0, n };
for(int i = 0; i <= 30; ++i) if(n >> i & 1)
blk[++cnt] = { i, n >> i + 1 << i + 1 };
if(ppcnt(n) == b) for(int i = 0; i <= k; ++i) if(i & 1) ++f[i];
if(!b) for(int i = 0; i <= k; ++i) if(!(i & 1)) ++f[i];
f[0] = !b;
for(int i = 2, sum = 1; i <= cnt; ++i) {
ll A = 1 << blk[i].k, B = sum;
ll w[2] = { 1, 1 }, invA = fsp(A, mod - 2);
for(int k = 1; k <= ::k; ++k) {
w[0] = w[0] * (A + B) % mod;
w[1] = w[1] * (A - B) % mod;
for(int x : { 0, 1 }) {
int o = k - x & 1 ? blk[i].c : 0;
int z = x & 1 ? blk[i].c + (1 << blk[i].k) : 0;
int t = b - ppcnt(o ^ z);
ll fa = (!x ? w[0] + w[1] : w[0] - w[1]) * inv2 % mod * invA % mod;
if(t >= 0) (f[k] += C[blk[i].k][t] * fa) %= mod;
if(!(k - x & 1)) {
(f[k] -= invA * fsp(B, k) % mod * C[blk[i].k][t]) %= mod;
}
}
}
sum += 1 << blk[i].k;
}
// g_i <- g_i - g_j [x^i]((e^x - e^{-x}) / 2)^j((e^x + e^{-x}) / 2)^{n + 1 - j}
int cur = 0;
coef[0][0] = 1, g[0] = f[0];
ll fac = 1;
for(int i = 1; i <= k; ++i) {
fac = fac * i % mod;
cur ^= 1;
for(int j = 0; j <= i; ++j) {
coef[cur][j] = j ? j * coef[cur ^ 1][j - 1] % mod : 0;
(coef[cur][j] += (n - j + 1) * coef[cur ^ 1][j + 1]) %= mod;
}
g[i] = f[i];
for(int t = i - 1; t >= 0; --t)
(g[i] -= g[t] * coef[cur][t]) %= mod;
g[i] = fsp(fac, mod - 2, g[i]);
}
printf("%lld\n", (g[k] + mod) % mod);
}