题意
给出一棵n个点的树,求包含1号点的第k小的连通块权值和。(\(n<=10^5\))
分析
k小一般考虑堆...
题解
堆中关键字为\(s(x)+min(a)\),其中\(s(x)\)表示\(x\)状态的权值和,\(min(a)\)表示\(x\)状态相邻的不在\(x\)里的的点的最小权值。
每一次从堆中弹出最小的,然后用这个来拓展。
可以证明,这样第\(k\)次弹出来的状态\(x \cup \\{ a \\}\)就是\(k\)小的。
证明很简单,堆中的都是待拓展状态,每一次取出来\(x \cup \\{ a \\}\)的将来再也不会有状态的权值和小于这个状态。
维护相邻的点我们可以用可持久化堆搞。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MX=8000005; //超过这个数这份代码别想活了【捂脸熊
struct node *null;
struct node {
node *c[2];
int x, w, s;
void up() { s=1+c[0]->s+c[1]->s; if(c[0]->s<c[1]->s) swap(c[0], c[1]); }
}pool[MX], *iT=pool;
int cc=2;
node *newnode() {
node *x=iT++; ++cc; if(cc>=MX) { puts("err"); exit(0); }
x->c[0]=x->c[1]=null; x->x=-1; x->w=0; x->s=0;
return x;
}
void cpy(node *x, node *y) {
x->x=y->x; x->w=y->w; x->s=y->s;
x->c[0]=y->c[0]; x->c[1]=y->c[1];
}
void init() {
null=iT++;
null->c[0]=null->c[1]=null;
null->x=0; null->w=0; null->s=0;
}
node *merge(node *l, node *r) {
if(l==null) return r;
if(r==null) return l;
node *x=newnode();
if(l->w<=r->w) { cpy(x, l); x->c[1]=merge(x->c[1], r); }
else { cpy(x, r); x->c[1]=merge(x->c[1], l); }
x->up();
return x;
}
node *ins(node *x, node *y) {
if(x==null) return y;
node *p=x;
if(y->w<=x->w) p=y, p->c[0]=x;
else x->c[1]=ins(x->c[1], y);
p->up();
return p;
}
node *ins(node *x, int id, int w, int flag=1) {
node *y=newnode(); y->x=id; y->w=w; y->s=1;
return flag?merge(x, y):ins(x, y);
}
node *del(node *x) { return merge(x->c[0], x->c[1]); }
struct ip {
node *x; ll sum;
bool operator<(const ip &a) const { return sum<a.sum; }
};
multiset<ip> q;
const int N=100005;
int ihead[N], f[N], cnt, n, K;
struct E { int next, to, w; }e[N<<1];
void add(int x, int y, int w) {
e[++cnt]=(E){ihead[x], y, w}; ihead[x]=cnt;
e[++cnt]=(E){ihead[y], x, w}; ihead[y]=cnt;
}
node *root[N];
int main() {
init();
scanf("%d%d", &n, &K);
for(int i=2; i<=n; ++i) {
int w;
scanf("%d%d", &f[i], &w);
add(i, f[i], w);
}
node *rt=null;
for(int x=1; x<=n; ++x) {
root[x]=null;
for(int i=ihead[x]; i; i=e[i].next) if(e[i].to!=f[x])
root[x]=ins(root[x], e[i].to, e[i].w, 0);
}
root[n+1]=ins(null, 1, 0, 0);
q.insert((ip){root[n+1], 0});
ll ans=0; int cnt=0; ip t;
multiset<ip>::iterator it;
while(1) {
if(q.size()==0) break;
it=q.begin(); q.erase(it);
t=*it;
ans=t.sum;
node *now=t.x;
if(now==null) continue;
rt=del(now); if(rt!=null) q.insert((ip){rt, t.sum-now->w+rt->w});
rt=merge(rt, root[now->x]);
q.insert((ip){rt, t.sum+rt->w});
++cnt; if(cnt==K) { ans=t.sum; break; }
if((int)q.size()>K) { it=q.end(); q.erase(--it); }
}
printf("%lld\n", ans%998244353);
return 0;
}