对于一个长度为 \(n\) 的序列 \(a\) 做插入排序,即依次考虑 \(a_2, \cdots, a_n\) 每个元素:
-
如果 \(a_i \ge a_{i - 1}\),即前缀依然保持不下降,不做操作。
-
否则,找到最靠前的位置 \(p\),使得 \(a_i < a_p\),然后将 \(a_i\) 插入到 \(a_p\) 前面,并重新标号序列。记这次插入为 \((i, p)\)。
接下来给定 \(m\) 个二元组 \((x_i, y_i)\),求有多少个长度为 \(n\) 的序列 \(a\),满足 \(\forall i, a_i \in [1, n] \cap \mathbb N\) 且做插入排序恰好进行给定的 \(m\) 次插入。
输出序列的数量对 \(998244353\) 取模的结果。
\(2 \le n \le 2 \cdot 10^5\),\(0 \le m < n\),\(2 \le x_1 < x_2 < \cdots \le n\),\(1 \le y_i < x_i\),\(\sum n \le 2 \cdot 10^5\),\(\sum m\) 没有限制。
3s, 512MB
当 \(m\) 次插入确定了以后,本质上是一个从原序列到新序列的置换,例如 \(n = 3\),只进行了 \((2, 1)\) 这次插入,那么相当于是从原先的 \([a_1, a_2, a_3]\) 变成了 \([a_2, a_1, a_3]\)。
那么,根据最终序列是不下降的,可以得到一个 \(a_2 \le a_1 \le a_3\) 的不等式。
这个条件显然不是充要的,原因在于,当发生插入 \((x, y)\) 时,确定了一组明确的严格小于的关系(被插入的数 \(<\) 原来该位置上的数,注意不是 \(a_x < a_y\)),而非简单的不下降。
在上面的例子中,发生 \((2, 1)\) 这次插入表示 \(a_2 < a_1\)。所以真正严格的限制应当是 \(a_2 < a_1 \le a_3\)。
先不妨不关心最终的限制是什么,假设我们知道了最终限制中有 \(c\) 个「小于号」和 \(n - 1 - c\) 个「小于等于号」,根据隔板法,可以知道方案数应当是 \(\binom{2n - 1 - c}{n}\)。
问题转化为了求出 \(c\) 的值,即小于号的个数。
根据上面的过程,可以发现,一个数前面是小于号而不是小于等于号的要求是「有另一个数直接插入到了它的前面」。
考虑逆序这些插入,维护一个集合 \(S\)。初始时 \(S\) 包含了 \(1,2, \cdots, n\) 的所有的数。
每次遇到一个 \((x_i, y_i)\) 时,查询 \(S\) 中第 \(y_i\) 小的数 \(p\) 和第 \(y_i + 1\) 小的数 \(q\),即将 \(p\) 插入到了 \(q\) 之前。那么从 \(S\) 中删除 \(p\),令 \(tag_q \gets 1\),表示有个数插入到了 \(q\) 之前。
最终 \(c\) 就等于 $tag= 1 $ 的位置的个数。
注意实现应当是 \(O(m \log n)\),不要搞成了 \(O(n \log n)\)。
代码用线段树二分实现的。
#include <bits/stdc++.h>
typedef long long ll;
const int N = 2e5 + 5, MOD = 998244353;
int n, m, x[N], y[N], roll[N], val[N << 2], fac[N << 1], ifac[N << 1];
std::set<int> pos;
int qpow(int a, int p) {
int res = 1;
while(p) {
if(p & 1) res = (ll)res * a % MOD;
a = (ll)a * a % MOD; p >>= 1;
}
return res;
}
void init(int l, int r, int rt) {
val[rt] = r - l + 1;
if(l == r) return;
int mid = (l + r) >> 1;
init(l, mid, rt << 1); init(mid + 1, r, rt << 1 | 1);
}
int query(int l, int r, int rt, int k) {
if(l == r) return l;
int mid = (l + r) >> 1;
if(val[rt << 1] >= k) return query(l, mid, rt << 1, k);
else return query(mid + 1, r, rt << 1 | 1, k - val[rt << 1]);
}
void modify(int l, int r, int rt, int p, int v) {
val[rt] += v;
if(l == r) return;
int mid = (l + r) >> 1;
if(p <= mid) modify(l, mid, rt << 1, p, v);
else modify(mid + 1, r, rt << 1 | 1, p, v);
}
int C(int a, int b) {
if(a < 0 || b < 0 || a < b) return 0;
return (ll)fac[a] * ifac[b] % MOD * ifac[a - b] % MOD;
}
int main() {
fac[0] = 1;
for(int i = 1; i < N * 2; i++) fac[i] = (ll)fac[i - 1] * i % MOD;
ifac[N * 2 - 1] = qpow(fac[N * 2 - 1], MOD - 2);
for(int i = N * 2 - 1; i; i--) ifac[i - 1] = (ll)ifac[i] * i % MOD;
int test; scanf("%d", &test); init(1, N - 1, 1);
while(test--) {
scanf("%d %d", &n, &m); pos.clear();
for(int i = 1; i <= m; i++) scanf("%d %d", &x[i], &y[i]);
for(int i = m; i; i--) {
int p = query(1, N - 1, 1, y[i]), q = query(1, N - 1, 1, y[i] + 1);
modify(1, N - 1, 1, p, -1);
roll[i] = p; pos.insert(q);
}
for(int i = 1; i <= m; i++) modify(1, N - 1, 1, roll[i], 1);
int c = (int)pos.size();
printf("%d\n", C(n * 2 - 1 - c, n));
}
return 0;
}