MARK on 2022.1.3:由于本人觉得“组合数学杂题选做”这篇博客太累赘了,故将其删除并将其中所有题解都单独开一篇博客写入。
组合数学好题。
Key observation:对于一个字符串 \(s\),其最长的“弱子串”的长度等于 \(|s|\) 减去 \(\min(s\text{最长的字符都不同的前缀},s\text{最长的字符都不同的后缀})\)。
证明:我们先来证一个引理:如果一个子串是 \(s\) 的弱子串,那么其对应的符合要求的子序列必然可以通过在子串的基础上移动一个字符得到。该定理的证明大概就,考虑一个子序列 \(i_1,i_2,\cdots,i_k(i_t<i_{t+1})\) 和一个子串 \(j_1,j_2,\cdots,j_k(j_t<j_{t+1})\),如果它们公共的元素个数 \(<k-1\),那么我们考虑最大的 \(t\) 满足 \(i_t<j_t\),由于 \(t\) 是符合要求的下标中最大的,必然有 \(i_{t+1}\ge j_{t+1}\),因此 \(j_t\sim i_{t+1}\) 之间没有任何被选择的元素,我们完全可以直接将 \(i_t\) 改为 \(j_t\),这样重合部分大小会 \(+1\),又由于原本重合部分大小 \(<k-1\),因此改完之后子序列还是符合弱子序列的要求,故我们的修改是符合要求的。如此进行下去直至重合部分达到 \(k-1\) 即可。
于是我们不妨考虑一个子串,由于我们只能移动一个元素,因此我们只能移动第一个或者最后一个元素,如果我们移动第一个元素,那么我们找到 \(s\) 最长的字符都不同的前缀长度 \(len\),并令该子串为 \(s[len+1…|s|]\),容易证明这样操作是所有只移动第一个元素的情况中最优的,对于移动最后一个元素的情况也是类似的,两者取个 \(\max\) 就是 \(|s|-\min(s\text{最长的字符都不同的前缀},s\text{最长的字符都不同的后缀})\)。
接下来考虑如何计算答案。考虑容斥,拿弱子串长度 \(\le w\) 的方案数减去弱子串长度 \(\le w-1\) 的方案数即可得到答案,于是问题转化为求最长弱子串长度 \(\le w\) 的方案数。根据上面的式子可以看出,长度 \(\le w\) 的长度都符合要求,方案数 \(\sum\limits_{i=1}^wk^i\),一波等比数列求和/分治带走,对于长度为 \(w+i(i\in[1,k])\) 的情况,字符串符合要求当且仅当前 \(i\) 位互不相同,后 \(i\) 位也互不相同,继续分情况:
- 如果 \(2i\le w+i\),那么前 \(i\) 位与后 \(i\) 位是独立的,方案数 \((k^{\underline{i}})^2·k^{w-i}\)。
- 否则记 \(c=2i-(w+i)\),那么中间方案数 \(k^{\underline{c}}\),两侧方案数都是 \((k-c)^{\underline{i-c}}\),总方案数 \(k^{\underline{c}}·((k-c)^{\underline{i-c}})^2\)
而对于长度 \(>w+k\) 的字符串,由于字符集大小为 \(k\),显然不符合要求。
直接计算即可,时间复杂度 \(k\log w\)。
const int MAXN = 1e6;
const int MOD = 1e9 + 7;
int k, w, fac[MAXN + 5], ifac[MAXN + 5];
int qpow(int x, int e) {
int ret = 1;
for (; e; e >>= 1, x = 1ll * x * x % MOD)
if (e & 1) ret = 1ll * ret * x % MOD;
return ret;
}
void init_fac(int n) {
for (int i = (fac[0] = ifac[0] = ifac[1] = 1) + 1; i <= n; i++)
ifac[i] = 1ll * ifac[MOD % i] * (MOD - MOD / i) % MOD;
for (int i = 1; i <= n; i++) {
fac[i] = 1ll * fac[i - 1] * i % MOD;
ifac[i] = 1ll * ifac[i - 1] * ifac[i] % MOD;
}
}
int calc_spw(int t, int k) {
if (!t) return 0; if (t == 1) return k;
int mid = t >> 1, sum = calc_spw(mid, k);
sum = (sum + 1ll * sum * qpow(k, mid)) % MOD;
if (t & 1) sum = (sum + qpow(k, t)) % MOD;
return sum;
}
int calc_A(int n, int k) {
if (n < k) return 0;
return 1ll * fac[n] * ifac[n - k] % MOD;
}
int calc(int k, int w) {
int ss = calc_spw(w, k);
for (int i = 1; i <= k; i++) {
int len = i + w;
if (i + i <= len) ss = (ss + 1ll * calc_A(k, i) * calc_A(k, i) % MOD * qpow(k, len - i - i)) % MOD;
else {
int cap = i + i - len;
ss = (ss + 1ll * calc_A(k, cap) * calc_A(k - cap, i - cap) % MOD * calc_A(k - cap, i - cap)) % MOD;
}
}
// printf("%d %d %d\n", k, w, ss);
return ss;
}
int main() {
scanf("%d%d", &k, &w); init_fac(MAXN);
printf("%d\n", (calc(k, w) - calc(k, w - 1) + MOD) % MOD);
return 0;
}