<div class="post_brief"><p>
我的第二道插头DP。差不多纠结了一天。</p>
这道题网上的题解都讲的好简略啊。最后还得自己yy。没有用括号序列法让我觉得自己比较厉害。
考虑轮廓线上的m个格子,有3种情况:没有插头,有且只能往一边走,有且要往两边走。对于第二种,要用最多3个不同的值来记录连通性。还要有两个不同的值来表示第一种情况和第三种情况。我比较懒,所以直接用一个八进制数去表示,然后每次二分找编号。
考虑转移,分一堆情况进行讨论。默认从左上朝右下走。
首先是要走当前格子的情况。
如果左边和上面都有插头,那么看它们是否连通。如果不连通则把它们连通,如果已经连通,看是否还有其它插头。没有就更新答案,否则忽略。
如果只有左边有插头,再分类。如果它是第二种,那么这个格子可以新建一个第三种插头,或者把左边格子的插头接过来。如果它是第三种那么它必需把左边格子接过来。
如果只有上面有插头,那么必需把上面的插头接上来。
如果左边和上面都没有插头,那么可以新建一个第三种插头。
然后如果不走这个格子的话,要求上面没有插头,且左边不是第三种插头。
按mhy的话说,插头dp就是写起来很麻烦。的确。不过写出来也还是挺高兴的。
然后代码丑到不能看了。
#include <cstdio> #include <cstring> #include <algorithm>using namespace std;
const int maxn = 109; const int maxm = 9; const int maxst = 50009; const int inf = 0x3f3f3f3f;
#define pow8(x) (1<<(x)<<(x)<<(x)) #define mbit(x,y) ((x)<<(y)<<(y)<<(y)) #define gbit(x,y) (((x)>>(y)>>(y)>>(y))&7)
int f[2][maxst], n, m, a[maxn][maxm], ans; int slst[maxst], tots;
inline void upMax(int& a, int b) { if (a < b) a = b; }
void dfsState(int cur, int tot, int val, int e) { if (cur == m) { static int cnt[maxm]; memset(cnt, 0, sizeof(cnt)); for (int i = 0; i < m; ++ i) ++ cnt[gbit(val, i)]; bool lg(1); for (int i = 1; i <= tot && lg; ++ i) if (cnt[i] != 0 && cnt[i] != 2) lg = 0; if (lg) slst[tots ++] = val; } else { for (int i = 0; i <= tot; ++ i) dfsState(cur + 1, tot, val | mbit(i, cur), e); dfsState(cur + 1, tot + 1, val | mbit(tot + 1, cur), e); if (e) dfsState(cur + 1, tot, val | mbit(7, cur), 0); } } void preState() { tots = 0; dfsState(0, 0, 0, 1); sort(slst, slst + tots); }
int fState(int s) { int ret(lower_bound(slst, slst + tots, s) - slst); if (slst[ret] != s) puts(“naive”); return ret; }
int joinState(int s, int p) { static int z[maxm], y[maxn], c; for (int i = 0; i < m; ++ i) z[i] = gbit(s, i); if (z[p - 1] == 7) { z[p - 1] = z[p]; z[p] = 0; } else { int f(z[p]), t(z[p - 1]); for (int i = 0; i < m; ++ i) if (z[i] == f) z[i] = t; z[p - 1] = 0; z[p] = 0; memset(y, 0, sizeof(y)); c = 0; for (int i = 0; i < m; ++ i) if (z[i] && z[i] < 7) { if (y[z[i]]) { z[i] = y[z[i]]; } else { y[z[i]] = ++ c; z[i] = y[z[i]]; } } } int f(0); for (int i = 0; i < m; ++ i) f |= mbit(z[i], i); return fState(f); }
int expState(int s, int p) { static int z[maxm]; for (int i = 0; i < m; ++ i) z[i] = gbit(s, i); int f(1); for (int i = 0; i < p - 1; ++ i) if (z[i] < 7 && z[i] + 1 > f) f = z[i] + 1; for (int i = 0; i < m; ++ i) if (z[i] < 7 && z[i] >= f) ++ z[i]; z[p - 1] = f; z[p] = f; f = 0; for (int i = 0; i < m; ++ i) f |= mbit(z[i], i); return fState(f); } int extState(int s, int p) { static int z[maxm]; for (int i = 0; i < m; ++ i) z[i] = gbit(s, i); z[p] = z[p - 1]; z[p - 1] = 0; int f(0); for (int i = 0; i < m; ++ i) f |= mbit(z[i], i); return fState(f); }
int forkState(int s, int p) { return fState(s | mbit(7, p)); }
bool legalState(int s, int p) { static int z[maxm]; for (int i = 0; i < m; ++ i) z[i] = gbit(s, i); for (int i = 0; i < p - 1; ++ i) if (z[i]) return 0; for (int i = p + 1; i < m; ++ i) if (z[i]) return 0; return 1; }
void fillArr(int* a, int sz, int v) { for (int i = 0; i < sz; ++ i) a[i] = v; }
void dp() { int cur(0), prv(1); fillArr(f[cur], tots, -inf); f[cur][0] = 0; for (int x = 0; x < n; ++ x) { for (int y = 0; y < m; ++ y) { swap(cur, prv); fillArr(f[cur], tots, -inf); f[cur][0] = 0; for (int i = 0; i < tots; ++ i) if (f[prv][i] > -inf) { int s(slst[i]); if (gbit(s, y) && y && gbit(s, y - 1)) { if (gbit(s, y) == gbit(s, y - 1)) { if (legalState(s, y)) upMax(ans, f[prv][i] + a[x][y]); } else { upMax(f[cur][joinState(s, y)], f[prv][i] + a[x][y]); } } if (!gbit(s, y) && y && gbit(s, y - 1)) { if (gbit(s, y - 1) == 7) upMax(f[cur][expState(s, y)], f[prv][i] + a[x][y]); else upMax(f[cur][extState(s, y)], f[prv][i] + a[x][y]); } if (gbit(s, y) && (!y || gbit(s, y - 1) != 7)) upMax(f[cur][i], f[prv][i] + a[x][y]); if (y < m - 1 && !gbit(s, y) && (!y || gbit(s, y - 1) != 7)) upMax(f[cur][forkState(s, y)], f[prv][i] + a[x][y]); if (!gbit(s, y) && (!y || gbit(s, y - 1) != 7)) upMax(f[cur][i], f[prv][i]); } } } }
int main() { #ifndef ONLINE_JUDGE freopen(“in.txt”, “r”, stdin); #endif
scanf("%d%d", &n, &m); preState(); ans = -inf; for (int i = 0; i < n; ++ i) for (int j = 0; j < m; ++ j) scanf("%d", a[i] + j); ans = -inf; dp(); printf("%d\n", ans);
}