一、题目:
二、思路:
首先有一个非常简单的DP思路:设DP状态为 \(dp[i, j]\),表示把前 \(j\) 个元素分成 \(i\) 个部分所需要的最小花费。则有状态转移方程
\[dp[i,j]=\min\limits_{i\leq j'\leq j} \{dp[i-1,j'-1]+cost(j',j)\} \]在这里,\(i\) 是阶段,\(i\) 和 \(j\) 共同构成状态,\(j'\) 是决策。
接下来我们证明该DP具有决策单调性。首先让我们固定 \(i\),然后设 \(p_j\) 是 \(j\) 的最优决策点(如果有多个一样优的决策点,取最左边的一个)。要证明 \(p_j\) 具有决策单调性,只需证 \(cost\) 具有四边形不等式。
即证
\[\forall a<b,cost(a,b+1)+cost(a+1,b)\geq cost(a,b)+cost(a+1,b+1) \]即证
\[\forall a<b,cost(a, b+1)-cost(a,b)\geq cost(a+1,b+1)-cost(a+1,b) \]不等式左边可以看成是将元素 \(b+1\) 放入区间 \([a,b]\) 对 \(cost\) 的影响,不等式右边可以看成是将元素 \(b+1\) 放入区间 \([a+1,b]\) 对 \(cost\) 的影响,显然左边大于等于右边。因此原命题得证。
于是这道题就可以使用整体二分解决了!当然,为了保证复杂度,我们要保证每次计算 \(cost\) 的复杂度。这其实可以通过维护一个桶 \(cnt\),每次计算区间端点改变对 \(cnt\) 的影响即可。
时间复杂度 \(O(kn\log n)\)。
三、代码:
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
#define FILEIN(s) freopen(s".in", "r", stdin);
#define FILEOUT(s) freopen(s".out", "w", stdout)
#define mem(s, v) memset(s, v, sizeof s)
inline int read(void) {
int x = 0, f = 1; char ch = getchar();
while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
return f * x;
}
const int maxn = 1e5 + 5, maxk = 25;
const long long inf = 0x3f3f3f3f3f3f3f3f;
int n, K;
int cnt[maxn];
long long cost, dp[maxk][maxn];
int a[maxn];
int lastl, lastr;
void get(int l, int r) {
while (lastr > r) -- cnt[a[lastr]], cost -= cnt[a[lastr]], -- lastr;
while (lastr < r) ++ lastr, cost += cnt[a[lastr]], ++ cnt[a[lastr]];
while (lastl < l) -- cnt[a[lastl]], cost -= cnt[a[lastl]], ++ lastl;
while (lastl > l) -- lastl, cost += cnt[a[lastl]], ++ cnt[a[lastl]];
}
void solve(int L, int R, int l, int r, int id) {
if (L > R) return;
if (l > r) return;
int mid = (l + r) >> 1;
int pos = 0; long long minn = inf;
for (int i = max(L, id); i <= min(R, mid); ++ i) {
get(i, mid);
if (dp[id - 1][i - 1] + cost < minn) {
minn = dp[id - 1][i - 1] + cost; pos = i;
}
}
dp[id][mid] = minn;
solve(L, pos, l, mid - 1, id); solve(pos, R, mid + 1, r, id);
}
int main() {
n = read(); K = read();
for (int i = 1; i <= n; ++ i) {
a[i] = read();
}
mem(dp, 0x3f);
dp[0][0] = 0;
lastl = 1; lastr = 0;
for (int id = 1; id <= K; ++ id) {
solve(1, n, 1, n, id);
}
printf("%lld\n", dp[K][n]);
return 0;
}