将选取的$A$看成左括号,$B$看成右括号,那么答案是一个合法的括号序列。
那么只要重复取出$k$对价值最小的左右括号,保证每时每刻都是一个合法的括号序列即可。
将$($看成$1$,$)$看成$-1$,设$s[]$为前缀和。
如果当前取出的是$()$,那么对前缀和的影响为$[A,B-1]$区间加$1$。
如果当前取出的是$)($,那么对前缀和的影响为$[B,A-1]$区间减$1$,所以这种情况需要满足区间$s$的最小值不为$0$。
考虑用线段树维护这个序列,线段树上每个节点维护以下信息:
va:$A\leq B$情况的最优解。
vb:$A>B$情况的最优解,且满足$[A,B-1]$的区间$s$最小值大于当前区间的$s$最小值。
vc:$A>B$情况的最优解。
aa:区间内代价最小的$A$。
ab:区间内代价最小的$B$。
ba:区间内代价最小的$A$,满足$[st,A-1]$的区间$s$最小值大于区间$s$最小值。
bb:区间内代价最小的$B$,满足$[B,en]$的区间$s$最小值大于区间$s$最小值。
vm:区间$s$最小值。
tag:区间增量标记。
为了方便维护,可以考虑增加第$0$项,$A[0]=B[0]=inf$。
那么$[0,n]$区间的$vb$必定满足区间最小值不为$0$,然后贪心选取$k$次即可求出最优解。
时间复杂度$O(k\log n)$。
#include<cstdio>
const int N=500010,M=1050000,inf=1000000010;
long long ans;
int n,k,i,j,A[N],B[N],aa[M],ab[M],ba[M],bb[M],vm[M],tag[M];
struct P{
int x,y;
P(){}
P(int _x,int _y){x=_x,y=_y;}
P operator+(const P&b){return A[x]+B[y]<A[b.x]+B[b.y]?*this:b;}
}va[M],vb[M],vc[M],t;
inline void read(int&a){char c;while(!(((c=getchar())>='0')&&(c<='9')));a=c-'0';while(((c=getchar())>='0')&&(c<='9'))(a*=10)+=c-'0';}
inline void add1(int x,int p){vm[x]+=p;tag[x]+=p;}
inline void pb(int x){if(tag[x])add1(x<<1,tag[x]),add1(x<<1|1,tag[x]),tag[x]=0;}
inline void up(int x){
int l=x<<1,r=l|1;
va[x]=va[l]+va[r]+P(aa[l],ab[r]);
vc[x]=vc[l]+vc[r]+P(aa[r],ab[l]);
vb[x]=vb[l]+vb[r];
aa[x]=A[aa[l]]<A[aa[r]]?aa[l]:aa[r];
ab[x]=B[ab[l]]<B[ab[r]]?ab[l]:ab[r];
if(vm[l]<vm[r]){
vb[x]=vb[x]+vc[r]+P(aa[r],bb[l]);
ba[x]=ba[l];
bb[x]=B[ab[r]]<B[bb[l]]?ab[r]:bb[l];
vm[x]=vm[l];
}
if(vm[l]>vm[r]){
vb[x]=vb[x]+vc[l]+P(ba[r],ab[l]);
ba[x]=A[aa[l]]<A[ba[r]]?aa[l]:ba[r];
bb[x]=bb[r];
vm[x]=vm[r];
}
if(vm[l]==vm[r]){
vb[x]=vb[x]+P(ba[r],bb[l]);
ba[x]=ba[l];
bb[x]=bb[r];
vm[x]=vm[l];
}
}
void build(int x,int a,int b){
if(a==b){
va[x]=vc[x]=P(a,a),vb[x]=P(0,0);
aa[x]=ab[x]=ba[x]=a;
return;
}
int mid=(a+b)>>1;
build(x<<1,a,mid),build(x<<1|1,mid+1,b),up(x);
}
void add(int x,int a,int b,int c,int d,int p){
if(c<=a&&b<=d){add1(x,p);return;}
pb(x);
int mid=(a+b)>>1;
if(c<=mid)add(x<<1,a,mid,c,d,p);
if(d>mid)add(x<<1|1,mid+1,b,c,d,p);
up(x);
}
void change(int x,int a,int b,int c){
if(a==b)return;
pb(x);
int mid=(a+b)>>1;
if(c<=mid)change(x<<1,a,mid,c);else change(x<<1|1,mid+1,b,c);
up(x);
}
int main(){
read(n),read(k);
for(i=1;i<=n;i++)read(A[i]);
for(i=1;i<=n;i++)read(B[i]);
A[0]=B[0]=inf;
build(1,0,n);
while(k--){
t=va[1]+vb[1],i=t.x,j=t.y,ans+=A[i]+B[j];
if(i<j)add(1,0,n,i,j-1,1);
if(i>j)add(1,0,n,j,i-1,-1);
A[i]=inf,change(1,0,n,i);
B[j]=inf,change(1,0,n,j);
}
return printf("%lld",ans),0;
}