原题链接:https://codeforces.ml/contest/1527/problem/E
题意
给了一个长度为n的序列,其中 C o s t ( t ) = ∑ x ∈ s e t ( t ) l a s t ( x ) − f i r s t ( x ) Cost(t)=\sum_{x∈set(t)}last(x)-first(x) Cost(t)=∑x∈set(t)last(x)−first(x),我们可以将序列分成k段,问minCost是多少
分析
不难想到dp的状态 d p [ i ] [ j ] dp[i][j] dp[i][j]代表前i个数分成j组时的最小花费,然后先推出一个暴力的DP方程
d p [ i ] [ j ] = m i n ( d [ k ] [ j − 1 ] + v a l ( k + 1 , i ) ) k ∈ [ 0 , j − 1 ] dp[i][j] = min(d[k][j-1]+val(k+1,i))k∈[0, j-1] dp[i][j]=min(d[k][j−1]+val(k+1,i))k∈[0,j−1]
v a l ( i , j ) 代 表 从 [ i , j ] 区 间 的 花 费 val(i,j)代表从[i,j]区间的花费 val(i,j)代表从[i,j]区间的花费
如果直接暴力去找肯定是超时的,这时候就可以用数据结构去优化DP。首先考虑怎么算区间内的花费,我们记录一个last[x]表示这个数前一次出现的位置,然后存入 i − l a s t [ x ] i-last[x] i−last[x],统计区间和,这样就可以算出每个数最晚出现和最早出现之差,但这样是有问题的,因为有些数的 l a s t [ x ] < k + 1 last[x]<k+1 last[x]<k+1,这样的值对于区域是没有贡献的,因此我们倒过来考虑,去累加当前x对哪些区间有影响。 l a s t [ x ] > = k + 1 last[x]>=k+1 last[x]>=k+1推出 k < = l a s t [ x ] − 1 k<=last[x]-1 k<=last[x]−1,也就是说当前i对 [ 0 , l a s t [ x ] − 1 ] [0,last[x]-1] [0,last[x]−1]有 i − l a s t [ x ] i-last[x] i−last[x]的贡献。
Code
#include <bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define re register
typedef long long ll;
typedef pair<ll, ll> PII;
typedef unsigned long long ull;
const int N = 35005 + 20, M = 1e6 + 5, INF = 0x3f3f3f3f;
const int MOD = 1e9+9;
int dp[N];
int a[N], last[N], pre[N];
struct node {
int l, r;
int sum, tag;
}t[N<<2];
void push_up(int u) {
t[u].sum = min(t[u<<1].sum, t[u<<1|1].sum);
}
void push_down(int u) {
if (t[u].tag) {
t[u<<1].tag += t[u].tag;
t[u<<1|1].tag += t[u].tag;
t[u<<1].sum += t[u].tag;
t[u<<1|1].sum += t[u].tag;
t[u].tag = 0;
}
}
void build(int u, int l, int r) {
t[u].l = l, t[u].r = r, t[u].tag = 0, t[u].sum = INF;
if (l == r) {
t[u].sum = dp[l];
return;
}
int mid = (l + r) >> 1;
build(u<<1, l, mid);
build(u<<1|1, mid+1, r);
push_up(u);
}
void modify(int u, int ql, int qr, int val) {
if (ql <= t[u].l && qr >= t[u].r) {
t[u].sum += val;
t[u].tag += val;
return;
}
push_down(u);
int mid = (t[u].l + t[u].r) >> 1;
if (ql <= mid) modify(u<<1, ql, qr, val);
if (qr > mid) modify(u<<1|1, ql, qr, val);
push_up(u);
}
int query(int u, int ql, int qr) {
if (ql <= t[u].l && qr >= t[u].r) return t[u].sum;
int mid = (t[u].l + t[u].r) >> 1;
int ans = INF;
push_down(u);
if (ql <= mid) ans = min(ans, query(u<<1, ql, qr));
if (qr > mid) ans = min(ans, query(u<<1|1, ql, qr));
return ans;
}
void solve() {
int n, k; cin >> n >> k;
for (int i = 1; i <= n; i++) {
cin >> a[i];
if (!last[a[i]]) pre[i] = i;
else pre[i] = last[a[i]];
last[a[i]] = i;
}
memset(dp, 0x3f, sizeof dp);
dp[0] = 0;
for (int i = 1; i <= k; i++) {
build(1, 0, n);
for (int j = 1; j <= n; j++) {
int p = pre[j];
int val = j - p;
modify(1, 0, p-1, val);
dp[j] = query(1, 0, j-1);
}
}
cout << dp[n] << endl;
}
signed main() {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
#ifdef ACM_LOCAL
freopen("input", "r", stdin);
freopen("output", "w", stdout);
#endif
solve();
}