题目大意
给你一个n*m的矩阵,要你求这个矩阵内的左右上下都对称的子正方形共有多少个。
解
把矩阵左右、上下镜面反过来,处理矩阵的哈希值。对比翻折前后的矩阵哈希值是否相等。
暴力枚举中心点,二分枚举矩阵大小。
(有单调性因为,当一个矩阵符合条件,以它的中心为中心的比他小的矩阵也符合条件)
矩阵中心有可能不在某个点上,所以跑两次二分。
代码
#include<cstdio>
#include<iostream>
#define h1 131
#define h2 107
#define ull unsigned long long
ull hash[1011][1011], hash_h[1011][1011], hash_l[1011][1011], t1[1011], t2[1011], ans, js, k;
int n, m, a[1011][1011], b[1011][1011], c[1011][1011];
bool check(int xx1, int yy1, int xx2, int yy2){
if(xx1 < 1 || yy1 < 1 || xx2 > n || yy2 > m) return 0;
int x1, x2, y1, y2, z1, z2, z3;
x1 = xx1; x2 = xx2; y1 = yy1; y2 = yy2;//坐标
z1 = hash[x2][y2] - hash[x1-1][y2] * t2[x2-x1+1] - hash[x2][y1-1] * t1[y2-y1+1]
+ hash[x1-1][y1-1] * t1[y2-y1+1] * t2[x2-x1+1]; //原矩阵
x1 = xx1; x2 = xx2; y1 = m - yy2 + 1; y2 = m - yy1 + 1;//坐标
z2 = hash_h[x2][y2] - hash_h[x1-1][y2] * t2[x2-x1+1] - hash_h[x2][y1-1] * t1[y2-y1+1]
+ hash_h[x1-1][y1-1] * t1[y2-y1+1] * t2[x2-x1+1]; //左右镜面
x1 = n-xx2+1; x2 = n-xx1+1; y1 = yy1; y2 = yy2; //坐标
z3 = hash_l[x2][y2] - hash_l[x1-1][y2] * t2[x2-x1+1] - hash_l[x2][y1-1] * t1[y2-y1+1]
+ hash_l[x1-1][y1-1] * t1[y2-y1+1] * t2[x2-x1+1]; //上下镜面
if(z1 == z2 && z1 == z3) return 1; //回文、对称
return 0;
}
int main(){
scanf("%d%d", &n, &m);
t1[0] = t2[0] = 1;
for(int i = 1; i <= 1000; ++i){ //预处理
t1[i] = t1[i-1] * h1;
t2[i] = t2[i-1] * h2;
}
for(int i = 1; i <= n; ++i)
for(int j = 1; j <= m; ++j)
scanf("%d", &a[i][j]);
for(int i = 1; i <= n; ++i)
for(int j = 1; j <= m; ++j){
b[i][j] = a[i][m-j+1]; //左右镜面
c[i][j] = a[n-i+1][j]; //上下镜面
}
for(int i = 1; i <= n; ++i) //得出hash值
for(int j = 1; j <= m; ++j) {
hash[i][j] = hash[i][j-1] * h1 + a[i][j];
hash_h[i][j] = hash_h[i][j-1] * h1 + b[i][j];
hash_l[i][j] = hash_l[i][j-1] * h1 + c[i][j];
}
for(int i = 1; i <= n; ++i)
for(int j = 1; j <= m; ++j) {
hash[i][j] += hash[i-1][j] * h2;
hash_h[i][j] += hash_h[i-1][j] * h2;
hash_l[i][j] += hash_l[i-1][j] * h2;
}
ans = n * m;
k = n<m?m:n;
for(int i = 1; i <= n; ++i) //枚举中心
for(int j = 1; j <= m; ++j) {
int l = 1, r = k; js = 0;
while(l <= r){ //以它为中心
int mid = (l+r) >> 1;
if(check(i-mid, j-mid, i+mid, j+mid) == 1) js = mid, l = mid + 1;
else r = mid - 1;
}
ans += js;
l = 1; r = k; js = 0;
while(l <= r){ //它在偏左上一点
int mid = (l+r) >> 1;
if(check(i-mid+1, j-mid+1, i+mid, j+mid) == 1) js = mid, l = mid + 1;
else r = mid - 1;
}
ans += js;
}
printf("%lld", ans);
}