题目描述
给定一张 \(n\times m\) 的网格,求这张网格上最多能放多少个炮兵。
网格上有一些地方不能放炮;炮兵的攻击范围是前、后、左、右各 \(2\) 格。
\(1 \leq n \leq 100 ,1 \leq m \leq 10\) 。
Solution
发现 \(m\) 很小,且第 \(i\) 行的状态之和 \(i-1,i-2\) 行有关,考虑状压 \(dp\) 。
设 \(f_{i,j,k}\) 表示前 \(i\) 行中,第 \(i\) 行的状态是 \(j\) ,第 \(i - 1\) 行的状态是 \(k\) 的方案中,最多能放多少个炮兵。
设 \(vaild(i,j)\) 表示 \(i,j\) 这两个状态能不能放在相邻两行, \(check(i,jj)\) 表示把状态 \(j\) 放到第 \(i\) 行是否合法, \(count(j)\) 表示状态 \(j\) 中放置了几个炮兵(二进制中 \(1\) 的个数)。
如果 \(check(i,j) = false\) 或者 ,\(check(i-1,k)=false\) ,那么 \(f_{i,j,k} = 0\) 。
考虑怎么计算状态,我们可以枚举第 \(i-2\) 行的状态 \(p\) ,则:
\[f_{i,j,k}=\begin{cases} \max\{f_{i-1,k,p}+count(j)\}&&vaild(j,k) \and vaild(j,p) \and vaild(k,p) \\ 0 && vaild(j,k)=false \end{cases} \]时间复杂度 \(\mathcal{O}(n2^m2^m2^m)=\mathcal{O}(n8^m)\) ,空间复杂度 \(\mathcal{O}(n4^m)\)
考虑优化:同一行内合法状态很少,所以我们可以先预处理出所有合法状态(最多 \(60\) 个),就可以把时间复杂度降到 \(\mathcal{O}(nc^3)\) ,空间复杂度降到 \(\mathcal{O(nc^2)}\) ,其中 \(c\) 表示合法状态数。
代码如下:
#include <cstdio>
#include <cstring>
#include <cctype>
inline int max(int a ,int b) {return a > b ? a : b;}
inline int lowbit(int x) {return x & (-x);}
inline int count(int x) {
int ans = 0;
while (x) ans++ ,x -= lowbit(x);
return ans;
}
inline bool check(int x) {
return ((x & (x >> 1)) == 0) && ((x & (x >> 2)) == 0);
}
inline bool check(int x ,int y) {return (x & y) == 0;}
inline bool check(int x ,int y ,int z) {
return check(x ,y) && check(x ,z) && check(y ,z);
}
const int N = 105 ,M = 15 ,S = 75;
int f[N][S][S] ,n ,m ,s[S] ,bin[S] ,idx ,v[N]; char str[M];
inline bool vaild(int i ,int j) {
return (v[i] & s[j]) == s[j];
}
signed main() {
scanf("%d%d" ,&n ,&m);
for (int i = 1; i <= n; i++) {
scanf("%s" ,str + 1);
for (int j = 1; j <= m; j++)
v[i] = v[i] << 1 | (str[j] == 'P');
}
for (int i = 0; i < (1 << m); i++)
if (check(i)) s[++idx] = i ,bin[idx] = count(i);
for (int i = 1; i <= idx; i++)
for (int j = 1; j <= idx; j++)
if (vaild(1 ,j) && vaild(2 ,i) && check(s[i] ,s[j]))
f[2][i][j] = bin[i] + bin[j];
//先预处理出第 2 行的信息,不然会越界造成错误。
//不用预处理第 1 行是因为这里没有用到。
for (int i = 3; i <= n; i++)
for (int j = 1; j <= idx; j++)
if (vaild(i ,j))
for (int k = 1; k <= idx; k++)
if (vaild(i - 1, k) && check(s[j] ,s[k]))
for (int p = 1; p <= idx; p++)
if (vaild(i - 2 ,p) && check(s[j] ,s[k] ,s[p]))
f[i][j][k] = max(f[i][j][k] ,f[i - 1][k][p] + bin[j]);
int ans = 0;
for (int i = 1; i <= idx; i++)
for (int j = 1; j <= idx; j++)
ans = max(ans ,f[n][i][j]);
printf("%d\n" ,ans);
return 0;
}