orz qyc
看很多写法都是 \(\mathcal{O}(n^4)\) 的,其实稍微预处理下就能做到 \(\mathcal{O}(n^3)\) 的了。
并不那么显然地看出是个区间 dp,后面就很好做了。
发现平方聚在一起是更优的,则区间 dp 应该是枚举一列让它尽可能的多选。
基于这个贪心的思路,设 \(f_{l,r}\) 为仅考虑左右端点均在 \([l,r]\) 内的段的答案,转移就枚举哪一行是尽可能填的。设 \(s_{k,l,r}\) 为代表左右端点都在 \([l,r]\) 内且经过 \(k\) 列的区间个数。
枚举尽可能多选的那一列是第 \(k\) 列,则有:
\[f_{l,r}=\max\{f_{l,k-1}+{s_{k,l,r}}^2+f_{k+1,r}\} \]然后考虑怎样算 \(s\)。
对于给出的一个段 \(x,y\),实际上是对所有的 \(l\leq x,r\geq y,x\leq k\leq y\),\(s_{k,l,r}\) 都加上 \(1\)。
也许读者注意到一个细节,\(s\) 的状态设计中,我将 \(k\) 放在第一维,实际上是方便对 \(\mathcal{O}(n^3)\) 预处理 \(s\) 作铺垫。考虑对每个 \(k\) 建立一个平面直角坐标系,横坐标为 \(l\),纵坐标为 \(r\),\((l,r)\) 上的值为 \(s_{k,l,r}\)。
那么所有满足 \(l\leq x,r\geq y\) 的 \((l,r)\),构成了一个左上角上的矩形,矩形加最后单点查,可以二维差分解决。
但实际上都是左上角的矩形,在 \((l,r)\) 打个 tag 然后做二维前缀和就完全足够。
时间复杂度 \(\mathcal{O(n^3)}\).
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
typedef long long ll;
const ll mod = 1000000007;
template <typename T> T Add(T x, T y) { return (x + y >= mod) ? (x + y - mod) : (x + y); }
template <typename T> T cAdd(T x, T y) { return x = (x + y >= mod) ? (x + y - mod) : (x + y); }
template <typename T> T Mul(T x, T y) { return x * y % mod; }
template <typename T> T Mod(T x) { return x < 0 ? (x + mod) : x; }
template <typename T> T Max(T x, T y) { return x > y ? x : y; }
template <typename T> T Min(T x, T y) { return x < y ? x : y; }
template <typename T> T Abs(T x) { return x < 0 ? -x : x; }
template <typename T> T chkmax(T &x, T y) { return x = x > y ? x : y; }
template <typename T>
T &read(T &r) {
r = 0; bool w = 0; char ch = getchar();
while(ch < '0' || ch > '9') w = ch == '-' ? 1 : 0, ch = getchar();
while(ch >= '0' && ch <= '9') r = (r << 3) + (r <<1) + (ch ^ 48), ch = getchar();
return r = w ? -r : r;
}
ll qpow(ll x, ll y) {
ll sumq = 1;
while(y) {
if(y & 1) sumq = sumq * x % mod;
x = x * x % mod;
y >>= 1;
}
return sumq;
}
const int N = 110;
int n, m;
int len[N], f[N][N], s[N][N][N];
signed main() { //freopen("in.txt", "r", stdin); freopen("out.txt", "w", stdout);
read(n); read(m);
for(int i = 1; i <= n; ++i) {
read(len[i]);
for(int j = 1; j <= len[i]; ++j) {
int x, y; read(x); read(y);
for(int k = x; k <= y; ++k)
++s[k][x][y];
}
}
for(int k = 1; k <= m; ++k) {
for(int l = m; l; --l)
for(int r = 1; r <= m; ++r)
s[k][l][r] += s[k][l][r-1];
for(int l = m; l; --l)
for(int r = 1; r <= m; ++r)
s[k][l][r] += s[k][l+1][r];
}
for(int k = 1; k <= m; ++k) f[k][k] = s[k][k][k] * s[k][k][k];
for(int len = 1; len < m; ++len)
for(int i = 1; i + len <= m; ++i) {
int j = i + len;
for(int k = i; k <= j; ++k) f[i][j] = Max(f[i][j], f[i][k-1] + s[k][i][j] * s[k][i][j] + f[k+1][j]);
}
printf("%d\n", f[1][m]);
return 0;
}