内容
\(wqs\) 二分又名凸优化、带权二分。
一般用于 \(n\) 个物品强制选 \(k\) 个的情况下的最优化问题。
这样的问题直接 \(dp\) 复杂度一般都比较高,因为要求强制选 \(k\) 个,所以要有一维来记录选了多少物品。
而 \(wqs\) 二分则可以把这种限制去掉。
首先我们二分一个权值 \(C\),强行给每一个物品都加上这一个权值。
然后跑一遍没有选 \(k\) 个物品的限制的 \(dp\)。
最后根据最优值所选择的物品个数来调整二分端点。
能够用 \(wqs\) 二分优化的 \(dp\) 要满足 \(dp\) 得到的结果是凸的。
也就是说,如果把横坐标看作强制选择的物品个数,纵坐标看作函数值,那么相邻两点之间的斜率应该是单调的。
之所以要有这个限制,是因为我们二分的附加权值实际上是斜率。
假设要求的是最大值,我们拿一条斜率为 \(k\) 的直线去切这个凸包,那么切到的点的截距一定是最大的。
但是我们并不知道我们具体切到了哪一个点,所以需要去计算。
根据直线的斜截式 \(y=kx+b\),截距 \(b=y-kx\) 。
我们可以把截距 \(b\) 也看成一个一次函数,那么如果能求出 \(b\) 的最值也就知道了当前切到的点的横坐标。
观察 \(b\) 的表达式,实际上就相当于给每一种物品减去了一个权值。
所以我们只要给物品减去权值之后跑一次不带限制的 \(dp\),求出最优的情况了选择了几个物品,就能知道切到的是哪一个点了。
但是有的时候会出现斜率相等的情况,这是就需要我们强制规定选横坐标最大/小的点。
例题
P2619 [国家集训队2]Tree I
分析
求恰好有 \(k\) 条白边的最小生成树。
可以给每一条白边加上额外的边权去跑最小生成树。
如果得到的生成树中白边比想要的多,就说明我们加的权值少了,要多加点,否则就少加点。
在斜率相等时,我们强制选择白边,也就是横坐标最大的点。
代码
#include<cstdio>
#include<algorithm>
#define rg register
inline int read(){
rg int x=0,fh=1;
rg char ch=getchar();
while(ch<'0' || ch>'9'){
if(ch=='-') fh=-1;
ch=getchar();
}
while(ch>='0' && ch<='9'){
x=(x<<1)+(x<<3)+(ch^48);
ch=getchar();
}
return x*fh;
}
const int maxn=1e5+5;
int n,m,k,ans,fa[maxn],cnt,whicnt,sum;
struct asd{
int zb,yb,val,jud;
}b[maxn];
bool cmp(rg asd aa,rg asd bb){
return aa.val==bb.val?aa.jud<bb.jud:aa.val<bb.val;
}
int zhao(rg int xx){
if(xx==fa[xx]) return xx;
return fa[xx]=zhao(fa[xx]);
}
bool jud(rg int val){
ans=cnt=whicnt=0;
for(rg int i=1;i<=m;i++) b[i].val+=(!b[i].jud)*val;
for(rg int i=1;i<=n;i++) fa[i]=i;
std::sort(b+1,b+m+1,cmp);
rg int aa,bb;
for(rg int i=1;i<=m;i++){
aa=b[i].zb,bb=b[i].yb;
aa=zhao(aa),bb=zhao(bb);
if(aa==bb) continue;
whicnt+=(b[i].jud==0),cnt++,fa[aa]=bb,ans+=b[i].val;
if(cnt==n-1) break;
}
for(rg int i=1;i<=m;i++) b[i].val-=(!b[i].jud)*val;
return whicnt>=k;
}
int main(){
n=read(),m=read(),k=read();
for(rg int i=1;i<=m;i++) b[i].zb=read()+1,b[i].yb=read()+1,b[i].val=read(),b[i].jud=read();
rg int l=-200,r=200,mids;
while(l<=r){
mids=(l+r)>>1;
if(jud(mids)){
l=mids+1;
sum=ans-k*mids;
} else {
r=mids-1;
}
}
printf("%d\n",sum);
return 0;
}
P4383 [八省联考2018]林克卡特树
分析
实际上是让你从树上选择 \(k+1\) 条不相交的链,使权值最大。
考虑 \(60\) 分的 \(dp\) 做法。
设 \(f[i][j][0/1/2]\) 为在 \(i\) 的子树中选择了 \(j\) 条链,\(i\) 的度数为 \(0,1,2\) 时的最大值。
之所以要加上度数的限制是为了合并子树的时候能够更好地处理信息。
度数为 \(0\) 代表当前点不在链上
度数为 \(1\) 代表当前点是链的一个端点
度数为 \(2\) 代表当前点在一条链的中心
每一次转移之后,我们都令 \(f[now][j][0]=max(f[now][j][0],max(f[now][j][2],f[now][j-1][1]))\)
这样我们在更新父亲节点的时候就不用特判很多情况
设 \(u\) 为 \(now\) 的儿子,\(val\) 代表边权
则 \(f[now][j][2]\) 可以由 \(f[now][k][2]+f[u][j-k][0]\) 和 \(f[now][k][1]+f[u][j-k-1][1]+val\) 更新而来
含义分别是继承之前的信息,当前点所在的链的一段与儿子节点所在的链的一端拼和成一条新的链并且当前点处在链的*
\(f[now][j][1]\) 可以由 \(f[now][k][1]+f[u][j-k][0]\) 和 \(f[now][k][0]+f[u][j-k-1][1]+val\) 更新而来
含义分别是继承之前的信息,当前边与儿子节点所在的链的一端拼和成一条新的链并且让当前节点作为链的一端
\(f[now][j][0]\) 直接继承 \(f[now][k][0]+f[u][j-k][0]\) 即可
一开始的时候要把一个节点也当链处理,即 \(f[now][0][0]=f[now][0][1]=f[now][1][2]=0\)
打表可得函数值是一个凸函数,斜率单调不增
所以可以用 \(wqs\) 二分优化
每次强制给每一条链加上一个权值,算一下最优的情况下选择了多少链
如果选择的链比想要的多,那么增加附加权值,少选一些
否则减小附加权值,多选一些
斜率相等的时候强制选择最左边的点
注意一下数组更新的顺序就行了
代码
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#define rg register
inline int read(){
rg int x=0,fh=1;
rg char ch=getchar();
while(ch<'0' || ch>'9'){
if(ch=='-') fh=-1;
ch=getchar();
}
while(ch>='0' && ch<='9'){
x=(x<<1)+(x<<3)+(ch^48);
ch=getchar();
}
return x*fh;
}
const int maxn=3e5+5;
typedef long long ll;
int h[maxn],tot=1,n,k;
struct asd{
int to,nxt,val;
}b[maxn<<1];
void ad(rg int aa,rg int bb,rg int cc){
b[tot].to=bb;
b[tot].nxt=h[aa];
b[tot].val=cc;
h[aa]=tot++;
}
struct jie{
int cnt;
ll val;
jie(){}
jie(rg int aa,rg ll bb){
cnt=aa,val=bb;
}
friend jie operator + (const jie& A,const jie& B){
return jie(A.cnt+B.cnt,A.val+B.val);
}
friend bool operator < (const jie& A,const jie& B){
if(A.val==B.val) return A.cnt<B.cnt;
return A.val<B.val;
}
}f[maxn][3];
jie Max(rg jie aa,rg jie bb){
return aa<bb?bb:aa;
}
void dfs(rg int now,rg int lat,rg ll val){
f[now][1]=f[now][0]=jie(0,0);
f[now][2]=jie(1,val);
for(rg int i=h[now];i!=-1;i=b[i].nxt){
rg int u=b[i].to;
if(u==lat) continue;
dfs(u,now,val);
f[now][2]=Max(f[now][2]+f[u][0],f[now][1]+f[u][1]+jie(1,b[i].val+val));
f[now][1]=Max(f[now][0]+f[u][1]+jie(0,b[i].val),f[now][1]+f[u][0]);
f[now][0]=Max(f[now][0],f[now][0]+f[u][0]);
}
f[now][0]=Max(f[now][0],Max(f[now][2],f[now][1]+jie(1,val)));
}
void init(){
for(rg int i=1;i<=n;i++){
f[i][0].cnt=f[i][1].cnt=f[i][2].cnt=0;
f[i][0].val=f[i][1].val=f[i][2].val=-0x3f3f3f3f3f3f3f3f;
}
}
int main(){
memset(h,-1,sizeof(h));
n=read(),k=read();
rg int aa,bb,cc;
for(rg int i=1;i<n;i++){
aa=read(),bb=read(),cc=read();
ad(aa,bb,cc);
ad(bb,aa,cc);
}
k++;
rg long long l=-3e11,r=3e11,mids,ans;
while(l<=r){
mids=(l+r)>>1;
init();
dfs(1,0,mids);
if(f[1][0].cnt<k){
l=mids+1;
} else {
ans=f[1][0].val-1LL*mids*k;
r=mids-1;
}
}
printf("%lld\n",ans);
return 0;
}