CF868F Yet Another Minimization Problem
解法
这个题首先考虑最基础的 DP。
显然,我们可以令 \(f_{i,j}\) 表示,考虑区间 \([1,i]\),分成 \(j\) 段的最小费用。那么我们最后求得就是 \(f_{n,k}\)。
转移也显然:
\[f_{i,j}=\min\limits_{t=1}^i f_{t-1,j-1}+w_{t,i} \]其中 \(w_{t,i}\) 表示 \([t,i]\) 区间内的费用。
这个 DP 显然 \(O(k\cdot n^2)\),要炸。由于 \(k\) 很小,我们考虑优化到 \(O(k\cdot n\log n)\)。CF的少爷机肯定跑的了。
我们考虑用整体二分优化。
首先理解一点,我们令 \(g_{i,j}\) 为 \(f_{i,j}\) 的决策点,即,使得 \(f_{i,j}\) 最小的那个 \(t\)。我们可以理性理解一下,\(j\) 不变,随着 \(i\) 增大,\(g_{i,j}\) 一定单调不严格递增。想要证明的话可以用四边形不等式。这个读者自证不难。
其实我自己证过,是对的。
那我们考虑固定 \(j\),计算区间 \([l,r]\) 内的 \(f_{i,j}\) 的值。我们现在暴力求解 \(f_{mid,j}\)。为了表达方便,记 \(p=g_{mid,j}\)。由于我们上面提到的那个性质,\(\forall t\in [l,mid-1]\) 的决策点一定在 \([l,p]\) 中,右半边同理。
这样,我们不停递归即可。我们可以发现,最多递归 \(\log n\) 层,每层复杂度是 \(O(n)\),再算上枚举 \(j\) 的时间,总复杂度为 \(O(k\cdot n\log n)\)。
至于怎么计算 \(w_{i,j}\),可以类比莫队。
代码
//Don't act like a loser.
//This code is written by huayucaiji
//You can only use the code for studying or finding mistakes
//Or,you'll be punished by Sakyamuni!!!
#include<bits/stdc++.h>
#define int long long
using namespace std;
int read() {
char ch=getchar();
int f=1,x=0;
while(ch<'0'||ch>'9') {
if(ch=='-')
f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9') {
x=x*10+ch-'0';
ch=getchar();
}
return f*x;
}
const int MAXN=1e5+10,INF=1e14;
int n,m,nl,nr,tot;
int f[MAXN][22],a[MAXN],sum[MAXN];
int calc(int x) {
return x*(x-1)/2;
}
void get(int l,int r) {
//类比莫队
while(nl<l) {
tot-=calc(sum[a[nl]]);
sum[a[nl]]--;
tot+=calc(sum[a[nl]]);
nl++;
}
while(nl>l) {
nl--;
tot-=calc(sum[a[nl]]);
sum[a[nl]]++;
tot+=calc(sum[a[nl]]);
}
while(nr<r) {
nr++;
tot-=calc(sum[a[nr]]);
sum[a[nr]]++;
tot+=calc(sum[a[nr]]);
}
while(nr>r) {
tot-=calc(sum[a[nr]]);
sum[a[nr]]--;
tot+=calc(sum[a[nr]]);
nr--;
}
}
void divide(int l,int r,int lp,int rp,int j) {
//[lp,rp]为决策点的可能区间范围。
if(r<l) {
return ;
}
int mark=0;
int mid=(l+r)>>1;
int rrpp=min(rp,mid-1);
f[mid][j]=INF;
for(int p=lp;p<=rrpp;p++) {
get(p+1,mid);
//f[mid][j]=min(f[mid][j],f[p][j-1]+tot);
if(f[mid][j]>f[p][j-1]+tot) {
f[mid][j]=f[p][j-1]+tot;
mark=p;
}
}
divide(l,mid-1,lp,mark,j);
divide(mid+1,r,mark,rp,j);
}
signed main() {
cin>>n>>m;
nl=1;
for(int i=1;i<=n;i++) {
a[i]=read();
}
for(int i=1;i<=n;i++) {
get(1,i);
f[i][1]=tot;
//j=1特别计算
}
for(int i=2;i<=m;i++) {
divide(1,n,1,n,i);
}
cout<<f[n][m]<<endl;
return 0;
}