OTL写manacher的mhy。
我没有那么高端去研究st怎么搞,所以就用了一个二维hash:从矩阵的四个角各算一遍二维前缀hash,再对每个中心二分最大的对称半径,真是爽翻了。需要注意的是,hash的时候行方向和列方向的基数必须不同(代码里的bh和bv),否则同一条反对角线上的元素互换位置(例如把矩阵转置)不会改变hash值,随便构造一组数据就能卡掉。
代码么,巨丑无比。
#include <cstdio>
#include <cctype>
#include <memory.h>
#include <algorithm>
using namespace std;
// Rename our 64-bit hash type so it cannot clash with a `uint` typedef that
// system headers may already provide: every later `uint` token in this file
// expands to `_uint_`.
#define uint _uint_
typedef unsigned long long uint;   // hash values wrap around modulo 2^64

// Array extent per dimension. preHash() touches index n + 1 / m + 1, so the
// input dimensions must satisfy n, m <= maxn - 2.
const int maxn = 1009;
// Polynomial hash bases: bh weights steps along rows (x), bv steps along
// columns (y). They must differ, otherwise cells on the same anti-diagonal
// become interchangeable in the hash.
const uint bh = 1000000093ULL;
const uint bv = 1000000007ULL;
// Reads the next non-negative decimal integer from `in` (stdin by default)
// into `s`, skipping any non-digit characters in front of it. The character
// that terminates the number is consumed.
// If end of input is reached before any digit is seen, `s` is left as 0 and
// the function returns (the previous macro spun forever in that case).
// Note: a leading '-' is skipped like any other non-digit character, so
// negative numbers are read as their absolute value.
template <typename T>
inline void readInt(T &s, FILE *in = stdin) {
    s = 0;
    int c;
    do {
        c = getc(in);
        if (c == EOF)
            return;
    } while (!isdigit(c));
    do {
        s = s * 10 + (c - '0');
        c = getc(in);
    } while (isdigit(c));
}
// Input matrix dimensions and values; main() fills a[1..n][1..m] (1-indexed).
int n, m;
int a[maxn][maxn];
// h[k][x][y]: 2D prefix hash accumulated from corner k (built by preHash()).
uint h[4][maxn][maxn];
// ph[i] = bh^i and pv[i] = bv^i (mod 2^64), also built by preHash().
uint ph[maxn * 2];
uint pv[maxn * 2];
void preHash() {
ph[0] = 1;
pv[0] = 1;
for (int i = 1; i < maxn * 2; ++ i)
ph[i] = ph[i - 1] * bh, pv[i] = pv[i - 1] * bv;
memset(h, 0, sizeof(h));
for (int ti = 0; ti < 4; ++ ti) {
int x0, x1, xd;
if (ti & 2) {
x0 = n;
x1 = 1;
xd = -1;
}
else {
x0 = 1;
x1 = n;
xd = 1;
}
int y0, y1, yd;
if (ti & 1) {
y0 = m;
y1 = 1;
yd = -1;
}
else {
y0 = 1;
y1 = m;
yd = 1;
}
for (int x = x0; x != x1 + xd; x += xd)
for (int y = y0; y != y1 + yd; y += yd)
h[ti][x][y] = h[ti][x - xd][y] * bh + h[ti][x][y - yd] * bv\
- h[ti][x - xd][y - yd] * bv * bh + (uint)a[x][y];
}
}
// Hash of a sub-rectangle of prefix table h[d].
// (x1, y1) is the rectangle corner that receives weight 1. (x0, y0) is the
// point just outside the opposite corner, on the side of table d's origin
// corner. The rectangle covers the rows from x1 up to (but excluding) x0
// and the columns from y1 up to (but excluding) y0. Cell (i, j) is weighted
// by bh^|i - x1| * bv^|j - y1|.
// The full 64-bit value is returned. The old `int` return type threw away
// half of the hash bits, and converting out-of-range unsigned values to int
// is implementation-defined before C++20.
// NOTE(review): hashing modulo 2^64 with fixed bases can be defeated by
// specially crafted inputs; a randomised base would harden it.
inline uint getRange(int d, int x0, int y0, int x1, int y1) {
    const int dx = x0 < x1 ? x1 - x0 : x0 - x1;
    const int dy = y0 < y1 ? y1 - y0 : y0 - y1;
    return h[d][x1][y1] - h[d][x0][y1] * ph[dx] - h[d][x1][y0] * pv[dy]
         + h[d][x0][y0] * ph[dx] * pv[dy];
}

// Counts the squares centred at doubled coordinates (tx, ty) that are both
// left-right and up-down symmetric. tx and ty must have the same parity, and
// preHash() must already have run.
//  - even tx, ty: odd-sized squares centred on cell (tx/2, ty/2);
//  - odd tx, ty: even-sized squares centred on the corner shared by rows
//    tx/2, tx/2 + 1 and columns ty/2, ty/2 + 1.
// A square of radius k is symmetric iff its four k x k quadrants, each read
// outward from the centre, are identical. That property is monotone in k, so
// the largest valid k is binary-searched. The count at this centre equals
// that k.
int calc(int tx, int ty) {
    const int x0 = tx >> 1, x1 = tx - x0;
    const int y0 = ty >> 1, y1 = ty - y0;
    // Even-sized case: the central 2x2 block must be constant, otherwise no
    // square around this corner is symmetric. It also makes radius 1 valid
    // below.
    if (x0 != x1 &&
        (a[x0][y0] != a[x0][y1] || a[x0][y0] != a[x1][y0] || a[x0][y0] != a[x1][y1]))
        return 0;
    int lo = 1;   // radius 1 is always symmetric at this point
    int hi = min(min(x0, n - x1 + 1), min(y0, m - y1 + 1));   // largest radius that fits
    while (lo < hi) {
        const int mid = (lo + hi + 1) >> 1;
        // Keep the full 64-bit hashes; truncating to int makes false matches
        // far more likely across the ~10^8 comparisons of a large input.
        const uint q0 = getRange(0, x0 - mid, y0 - mid, x0, y0);
        const uint q1 = getRange(1, x0 - mid, y1 + mid, x0, y1);
        const uint q2 = getRange(2, x1 + mid, y0 - mid, x1, y0);
        const uint q3 = getRange(3, x1 + mid, y1 + mid, x1, y1);
        if (q0 == q1 && q1 == q2 && q2 == q3)
            lo = mid;
        else
            hi = mid - 1;
    }
    return lo;
}
// Reads n, m and the n x m matrix, builds the hashes, then sums calc() over
// every candidate centre on the doubled grid and prints the total.
int main() {
#ifndef ONLINE_JUDGE
    freopen("in.txt", "r", stdin);
#endif
    readInt(n);
    readInt(m);
    for (int row = 1; row <= n; ++ row)
        for (int col = 1; col <= m; ++ col)
            readInt(a[row][col]);
    preHash();
    int total = 0;
    for (int i = 2; i <= n * 2; ++ i) {
        // Centres need equal parity in both doubled coordinates, so start j
        // at the first value >= 2 matching i's parity and step by 2.
        for (int j = 2 + (i & 1); j <= m * 2; j += 2)
            total += calc(i, j);
    }
    printf("%d\n", total);
    return 0;
}