statement
- 给定一个 \(n \times n\) 的表格,表格中的每个元素有 \(p_{i,j}\) 的概率为 \(1\),否则为 \(0\)。
- 求至少有一行或一列或一条对角线全为 \(1\) 的概率。对角线指主对角线或副对角线。
- \(n \le 21\)
solution 1
首先显然是补集转化,求行、列、对角线全部存在为一个为 \(0\) 的元素的概率。
有一种 naive 的思路,直接设 \(dp(mask)\) 表示在状态 \(mask\) 下的概率。
\(mask\) 的总量达到了惊人的 \(2 ^ {44}\) 。。。
因为 \(21 \times 2 ^ {21}\) 还是可以接受的,所以可以枚举每一行,记录每一列的状态。
这里状态的第 \(i\) 位表示在当前枚举的行中,这一列是否为 \(0\) 。
考虑如何合并答案,设刚刚这个求出的答案为 \(f(i)\) ,枚举到当前点的答案为 \(g(i)\) ,那么有:
不难发现这就是一个 or
卷积,可以用 \(\rm FWT\) 求出。
但是这样并没有考虑到对角线的问题,对于第 \(i\) 行,主对角线的元素位置在第 \(i\) 列,副对角线的位置在第 \(n - i + 1\) 列,然后发现这和每一列并没有什么本质区别,也很简单地解决了。
时间复杂度 \(\mathcal O (n ^ 2 \times 2 ^ n)\) 。
int n, mkp[1 << 23]; Mint zr[21][21], iv[21][21], f[1 << 23], g[1 << 23];
int main() {
Mint inv = q_pow(Mint(10000));
scanf("%d", &n);
rep(i,0,n) rep(j,0,n) {
static int x;
scanf("%d", &x);
zr[i][j] = Mint(1) - inv * Mint(x);
iv[i][j] = q_pow(inv * Mint(x));
}
int mask = 1 << n + 2;
rep(i,0,mask) g[i] = Mint(1);
rep(i,0,n) mkp[1 << i] = i;
rep(i,0,n) {
f[0] = Mint(1);
rep(j, 0, n) f[0] *= Mint(1) - zr[i][j];
rep(S, 1, 1 << n) f[S] = f[S ^ (S & -S)] * zr[i][mkp[S & -S]] * iv[i][mkp[S & -S]];
rep(S, 1 << n, mask) f[S] = Mint(0);
rep(S, 1, 1 << n)
if((S >> i & 1) && (S >> (n - i - 1) & 1)) f[S ^ (1 << n) ^ (1 << n + 1)] = f[S], f[S] = Mint(0);
else if(S >> i & 1) f[S ^ (1 << n)] = f[S], f[S] = Mint(0);
else if(S >> (n - i - 1) & 1) f[S ^ (1 << n + 1)] = f[S], f[S] = Mint(0);
f[0] = Mint(0);
DWT(f, mask);
rep(S, 0, mask) g[S] = g[S] * f[S];
}
IDWT(g, mask);
printf("%d\n", (Mint(1) - g[mask - 1]).res);
return 0;
}
最大点 6972ms
solution 2
然后是 \(\color{black}{\rm t}\color{red}{\rm ourist}\) 的官方解法。。。
直接容斥。。。
设 \(dp(S = \{a_1, a_2, ..., a_k\})\) 表示所有 \(a_i\) 列均满足条件的概率,这里的条件指的是所有列存在一个为 \(0\) 。
和上面的思路差不多,对于每一个新加进来的一行,设为 \(i\) ,都有一个转移:
其中 \(g(i, S)\) 表示第 \(i\) 行,在集合 \(S\) 中的列全部是 \(1\) 的概率。
对这个 \(f\) 加个容斥系数就好了。
然后,预处理 \(g\) 和 \(f\) 的时间复杂度都是 \(\mathcal O (n \times 2 ^ n)\) 的。
int n; Mint g[21][1 << 23], f[1 << 23];
int main() {
Mint inv = q_pow(Mint(10000));
scanf("%d", &n);
rep(i,0,n) {
g[i][0] = Mint(1);
rep(j,0,n) {
static int x;
scanf("%d", &x);
g[i][1 << j] = inv * Mint(x);
}
rep(S,1,1 << n) g[i][S] = g[i][S ^ (S & -S)] * g[i][S & -S];
}
int mask = 1 << n; Mint res = Mint(0);
rep(dia,0,4) {
rep(S,0,mask) if((__builtin_popcount(dia) + __builtin_popcount(S)) & 1) f[S] = Mint(Mod - 1);
else f[S] = Mint(1);
rep(i,0,n) rep(S,0,mask) f[S] *= g[i][S | ((dia & 1) << i) | (((dia & 2) == 2) << (n - i - 1))] - g[i][mask - 1];
rep(S,0,mask) res += f[S];
}
printf("%d\n", (Mint(1) - res).res);
return 0;
}