算是左偏树的模板题。真正的模板题懒得放了。首先总结一下什么是左偏树。
左偏树就是可并堆。假如想要把两个普通堆进行合并,最简单的方法就是把较小堆的元素一个一个假如较大堆中,应该是\(O(NlogN)\)左右吧。但我们发现这样的做法非常费时间,何不考虑直接把一个堆直接挂到另一个堆的下面呢。然后发现这样做可能会导致堆结构的山崩地裂,于是为了维护它的结构,就有了左偏树。
先说一下,这个偏左偏右似乎并没有什么关系,爱偏向哪边偏哪边。只不过既然大家都写的是左偏树,我也就没有尝试右偏树的写法了。
左偏树的代码其实挺优雅的。核心代码(就是合并部分)如下:
#define lc t[x].l
#define rc t[x].r
int merge(int x,int y){
if(!x||!y)return x+y;
if(t[y]<t[x])swap(x,y);
rc=merge(rc,y);
if(t[lc].dis<t[rc].dis)swap(lc,rc);
t[x].dis=t[rc].dis+1;return x;
}
和线段树合并有异曲同工之妙。第一行都是判断假如有一个空点,那么直接返回非空点编号即可(合并了个寂寞)。然后就考虑,由于左偏树普遍偏左(这不废话),贪心地把合并上来的堆挂在右子树上,进行递归处理即可。dis本质上是在维护偏的程度,只要保持左孩子的dis比右孩子大就可以保证这棵树一直左偏,这也就能保证合并的复杂度了。最后两行代码是在检查左右孩子哪个偏得更厉害。
然后就是说这道题了,这道题用到了一点点贪心的思想。一看到什么有且仅有子树内的点可能成为贡献点之类的话就应该想到递归合并,要么线段树合并要么堆合并。由于\(ans=w[i]\times num\),肯定是要最大化\(num\)的值。而总费用是一定的,那就是要贪心地找出子树内最小的那些节点即可。实现上可以维护一个大根堆,统计答案之前把堆顶元素一直删除直到堆内元素之和不大于费用。正确性也很好说,毕竟一个被弹出的元素不可能再成为贡献点了。
代码(发现我越来越习惯压行了):
#include<cstdio>
//#define zczc
#define ll long long
const int N=100010;
inline void read(int &wh){
wh=0;int f=1;char w=getchar();
while(w<'0'||w>'9'){if(w=='-')f=-1;w=getchar();}
while(w<='9'&&w>='0'){wh=wh*10+w-'0';w=getchar();}
wh*=f;return;
}
inline void swap(int &s1,int &s2){int s3=s1;s1=s2;s2=s3;return;}
int m,n,c[N],w[N];
struct edge{
int t,next;
}e[N];
int esum,head[N];
inline void add(int fr,int to){
e[++esum]=(edge){to,head[fr]};head[fr]=esum;
}
#define lc t[x].l
#define rc t[x].r
struct node{int l,r,v,dis,num;ll sum;}t[N];
int merge(int x,int y){
if(!x||!y)return x+y;
if(t[x].v<t[y].v)swap(x,y);rc=merge(rc,y);
t[x].sum=t[lc].sum+t[rc].sum+t[x].v;
t[x].num=t[lc].num+t[rc].num+1;
if(t[lc].dis<t[rc].dis)swap(lc,rc);
t[x].dis=t[rc].dis+1;return x;
}
int del(int x){return merge(lc,rc);}
#undef lc
#undef rc
int cnt,r[N];
ll ans,now;
void solve(int wh){
t[r[wh]=++cnt]=(node){0,0,c[wh],0,1,c[wh]};
for(int i=head[wh],th;i;i=e[i].next)
solve(th=e[i].t),r[wh]=merge(r[wh],r[th]);
while(t[r[wh]].sum>n)r[wh]=del(r[wh]);
ans=(now=(ll)t[r[wh]].num*w[wh])>ans?now:ans;
}
signed main(){
#ifdef zczc
freopen("in.txt","r",stdin);
#endif
int s1,root;read(m);read(n);
for(int i=1;i<=m;i++){
read(s1);add(s1?s1:root=s1,i);read(c[i]);read(w[i]);
}
solve(root);printf("%lld",ans);
return 0;
}