题面
题解
假设每个位置的值已经确定,为 \(a_i\),那么将 \((a_i, i)\) 二元组排序的方法唯一。
枚举最后的排序结果 \(p\),\(p_i\) 表示排序之后在排名为 \(i\) 的下标为 \(p_i\)。
那么可以反推出 \(a_{p_1} (< \mathrm{or} \leq) \ a_{p_2} (< \mathrm{or} \leq) \ \cdots \ (< \mathrm{or} \leq) \ a_{p_n}\)。
其中如果 \(p_i > p_{i + 1}\),那么 \(a_{p_i} \leq a_{p_{i + 1}}\) 否则是 \(<\)。
考虑将 \(<\) 变成 \(\leq\):若 \(a_{p_i}\) 和 \(a_{p_{i + 1}}\) 之间的限制为 \(<\),那么将所有 \(a_{p_j} (j > i)\) 的值全部减少 \(1\),那么这个位置的限制就变成了 \(\leq\),也就是说,\(p_{j} (j > i)\) 的所有位置的限制都减小了 \(1\)。
接下来就可以各显神通了:可以用这题的方法 dp,也可以用下面介绍的方法(orz Itst):
将限制放到格路上,那么相当于从 \((1, 1)\) 走到 \((n + 1, \infty)\),只能向上向右走,当横坐标为 \(i\) 时纵坐标必须 \(\leq A_i\)(也就是之前的限制)的方案数,可以容斥计算。这样,如果不将计算组合数的时间算在复杂度里面的话,时间复杂度为 \(\mathcal O(n^2)\)。
综上,时间复杂度为 \(\mathcal O(n!n^3)\),比 std 好像要优秀一些。
代码
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
inline int read()
{
int data = 0, w = 1; char ch = getchar();
while (ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
if (ch == '-') w = -1, ch = getchar();
while (ch >= '0' && ch <= '9') data = data * 10 + (ch ^ 48), ch = getchar();
return data * w;
}
const int N(10), Mod(1e9 + 7);
int n, a[N], id[N], h[N], ans, f[N];
int fastpow(int x, int y)
{
int ans = 1;
for (; y; y >>= 1, x = 1ll * x * x % Mod)
if (y & 1) ans = 1ll * ans * x % Mod;
return ans;
}
int C(int n, int m)
{
int s = 1;
for (int i = 1; i <= m; i++) s = 1ll * s * (n - i + 1) % Mod * fastpow(i, Mod - 2) % Mod;
return s;
}
int main()
{
n = read();
for (int i = 1; i <= n; i++) a[i] = read(), id[i] = i;
do
{
for (int i = 1; i <= n; i++) h[i] = a[id[i]] - 1;
for (int i = 1; i <= n; i++) if (id[i] < id[i + 1])
for (int j = i + 1; j <= n; j++) --h[j];
for (int i = n - 1; i; i--) h[i] = std::min(h[i], h[i + 1]);
std::memset(f, 0, sizeof f);
for (int i = 1; i <= n; i++)
{
f[i] = C(h[i] + i - 1, i - 1);
for (int j = 1; j < i; j++)
f[i] = (f[i] - 1ll * f[j] * C(h[i] - h[j] - 1 + i - j, i - j) % Mod + Mod) % Mod;
}
int res = C(h[n] + n, n), mx = 0;
for (int i = 1; i <= n; i++)
res = (res - 1ll * f[i] * C(h[n] - h[i] + n - i, n - i + 1) % Mod + Mod) % Mod;
std::memset(f, 0, sizeof f);
for (int i = 1; i <= n; mx = std::max(mx, f[i++]))
for (int j = f[i] = 1; j < i; j++)
if (id[j] < id[i]) f[i] = std::max(f[i], f[j] + 1);
ans = (ans + 1ll * res * mx) % Mod;
} while (std::next_permutation(id + 1, id + n + 1));
for (int i = 1; i <= n; i++) ans = 1ll * ans * fastpow(a[i], Mod - 2) % Mod;
printf("%d\n", ans);
return 0;
}