http://uoj.ac/problem/205
好神的题啊。
dp[i][j]表示以i为根的子树调整成长度j需要的最小代价。
首先要观察到dp值是一个下凸壳。
因为从儿子合并到父亲时要把所有儿子的凸壳相加,得到的还是一个凸壳。
父亲要把它连向它父亲的边的影响加入时,设这条边长度为len,则相当于把当前的这个凸壳先右移len,斜率大于1的部分斜率都重置为1,斜率小于1的部分都向左移len再向上移len,其中空出来的长度为len的部分用斜率为-1的连接起来。
就是把原凸壳先整体上移len,再删掉斜率大于等于0的部分,再添上3条斜率分别为-1,0,1的直线。
直接维护凸壳的复杂度是\(O\left((n+m)^2\right)\)的。
再来考虑一下凸壳的性质:
一个凸壳在x=0处的取值是子树内所有边权和;
当这个凸壳没有考虑当前点到它父亲的边的贡献时,这个凸壳最右端的直线的斜率是它的儿子数;
凸壳上的直线的斜率只可能是整数;
现在有了上面的性质,可以更简单地表达一个凸壳。
有了凸壳在x=0处的取值,我们只要知道一个凸壳的导函数就可以还原出一个凸壳。
有了凸壳最右端直线的斜率,也就是导函数的最大值,我们只要知道一个凸壳的二阶导就可以还原出凸壳的导函数。
也就是说不用维护凸壳,直接维护凸壳的二阶导数就可以了。
二阶导可以更直观的看成拐点,每个在第i个位置的拐点对二阶导的贡献为1(拐点的位置可以重叠)。
每次合并时直接合并两个拐点集合就可以了,每次考虑父亲边的贡献时删掉最靠右边的儿子数+1个拐点,再添加两个拐点。
因为每次都删权值最大的拐点,拐点集合可以用可并堆维护。
最后用根节点的拐点集合还原出根节点的凸壳就可以了。
每个节点只可能加进来两个拐点,每个拐点最多被弹出一次,时间复杂度\(O\left((n+m)\log(m+m)\right)\)。
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N = 600003;
struct node {
node *ch[2];
int v; ll pos;
node(ll _pos = 0) : pos(_pos) {ch[0] = ch[1] = 0; v = 0;}
} *rt[N];
int dist(node *r) {return r ? r->v : -1;}
node *merge(node *l, node *r) {
if (l == 0) return r;
if (r == 0) return l;
if (l->pos < r->pos) swap(l, r);
l->ch[1] = merge(l->ch[1], r);
if (dist(l->ch[0]) < dist(l->ch[1]))
swap(l->ch[0], l->ch[1]);
if (l->ch[1]) l->v = l->ch[1]->v + 1;
else l->v = 0;
return l;
}
void pop(node *&r) {
if (r) r = merge(r->ch[0], r->ch[1]);
}
ll sum = 0, pp[N << 1];
int fa[N << 1], len[N << 1], n, m, d[N << 1];
int main() {
scanf("%d%d", &n, &m);
for (int i = 2; i <= n + m; ++i) {
scanf("%d%d", fa + i, len + i);
sum += len[i];
++d[fa[i]];
}
node *n1, *n2;
for (int i = n + m; i > 1; --i) {
if (i > n) {
rt[i] = merge(new node(len[i]), new node(len[i]));
rt[fa[i]] = merge(rt[fa[i]], rt[i]);
continue;
}
while (--d[i]) pop(rt[i]);
n2 = rt[i]; pop(rt[i]);
n1 = rt[i]; pop(rt[i]);
rt[i] = merge(rt[i], new node(n1->pos + len[i]));
rt[i] = merge(rt[i], new node(n2->pos + len[i]));
rt[fa[i]] = merge(rt[fa[i]], rt[i]);
}
while (d[1]--) pop(rt[1]);
int ptot = 0;
while (rt[1]) {
pp[++ptot] = rt[1]->pos;
pop(rt[1]);
}
ll prepos = 0;
while (ptot) {
if (pp[ptot] != prepos) {
sum -= (pp[ptot] - prepos) * ptot;
prepos = pp[ptot];
}
--ptot;
}
printf("%lld\n", sum);
return 0;
}