https://www.luogu.com.cn/problem/P4292
长链剖分的难点主要在于指针的使用：重儿子直接复用父亲的数组（f[w[u]] = f[u] + 1），轻儿子则在公共数组 B 中新开一段长度为 len[v] + 1 的空间。
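具体实现见下面的代码。每向上继承一层，数组中的每个值都要加上一条边的权值（再减去二分的 X）；逐个修改会使一条长链的总代价退化为平方级，因此利用差分（整体偏移量）的思想：为每个点维护偏移量 sl[u]，数组中存的是“真实值 − sl[u]”，继承时只需修改偏移量，而不必改动数组本身。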
code:
#include<bits/stdc++.h>
#define N 500050
#define db double
#define ll long long
using namespace std;
// Adjacency list in "linked forward star" form:
//   e[k]  is the k-th inserted directed edge,
//   p[u]  is the index of the most recently inserted edge leaving u (-1 = none),
//   eid   is the number of edges stored so far.
// Every undirected edge is stored twice, hence the N << 1 capacity.
struct edge {
    int v;    // vertex the edge points to
    int c;    // edge weight
    int nxt;  // next edge leaving the same vertex, or -1
} e[N << 1];
int p[N], eid;

// Empty every adjacency list and rewind the edge counter.
void init() {
    for (int u = 0; u < N; ++u) p[u] = -1;
    eid = 0;
}

// Prepend the directed edge u -> v (weight c) to u's list.
void insert(int u, int v, int c) {
    e[eid] = edge{v, c, p[u]};
    p[u] = eid;
    ++eid;
}
// Problem input: vertex count n and the allowed range [L, R] for the
// number of edges on the chosen path.
int n, L, R;
// Tolerance for both the binary-search stopping rule (main) and the
// "best shifted sum is non-negative" test (check).
const db eps = 1e-5;
// Max segment tree over positions 1..n: ma[rt] is the maximum value that has
// been added anywhere inside node rt's range.
db ma[N << 2];
// Stand-in for minus infinity (empty nodes / empty query results).
const db inf = 1e18;
// Left / right child of segment-tree node rt (implicit heap layout, root = 1).
#define ls (rt << 1)
#define rs (rt << 1 | 1)
// Reset node rt (covering positions [l, r]) and all of its descendants to -inf.
void build(int rt, int l, int r) {
    ma[rt] = -inf;
    if (l < r) {
        const int half = (l + r) >> 1;
        build(2 * rt, l, half);
        build(2 * rt + 1, half + 1, r);
    }
}
// Raise position x to at least o: walk from node rt (range [l, r]) down to the
// leaf for x, lifting every node on the way to max(current, o).
void add(int rt, int l, int r, int x, db o) {
    for (;;) {
        if (ma[rt] < o) ma[rt] = o;
        if (l == r) return;
        const int half = (l + r) >> 1;
        if (x > half) {
            rt = 2 * rt + 1;
            l = half + 1;
        } else {
            rt = 2 * rt;
            r = half;
        }
    }
}
// Maximum stored value over positions [L, R] inside node rt's range [l, r];
// halves that do not intersect [L, R] contribute -inf.
db query(int rt, int l, int r, int L, int R) {
    if (L <= l && r <= R) return ma[rt];
    const int half = (l + r) >> 1;
    const db fromLeft = (L <= half) ? query(2 * rt, l, half, L, R) : -inf;
    const db fromRight = (R > half) ? query(2 * rt + 1, half + 1, r, L, R) : -inf;
    return max(fromLeft, fromRight);
}
// w[u]  : child of u with the longest downward chain (0 when u is a leaf).
// len[u]: number of edges on that longest downward chain from u.
// sl[u] : sum of (c - X) along the chain u -> w[u] -> w[w[u]] -> ...;
//         solve() uses it as u's offset (stored value + sl[u] = real value).
int w[N], len[N];
db sl[N];

// Fill w, len and sl for the subtree of u (entered from parent fa), with every
// edge weight shifted by -X.
void dfs(int u, int fa, db X) {
    w[u] = 0;
    sl[u] = 0;
    for (int k = p[u]; k != -1; k = e[k].nxt) {
        const int to = e[k].v;
        if (to == fa) continue;
        dfs(to, u, X);
        // len[0] is never written, so the first child always qualifies;
        // on ties the child visited later wins.
        if (len[to] < len[w[u]]) continue;
        w[u] = to;
        len[u] = len[to] + 1;
        sl[u] = sl[to] + (db)e[k].c - X;
    }
}
// pos[u]: index of u's first slot, used both in B[] and in the segment tree
//         (invariant: f[u] == B + pos[u]).
// id    : next free index, handed out when a light child starts a new chain.
int pos[N], id;
// f[u]  : u's DP array; f[u][j] + sl[u] is the best real (shifted) weight of a
//         downward path of exactly j edges starting at u, over the part of u's
//         subtree merged so far.
// B     : shared backing storage for all f arrays.
// ans   : best real weight found so far of a path whose edge count is in [L, R].
db *f[N << 1], B[N << 1], ans;
// Long-path decomposition DP for threshold X.
// Fills f[u][0..len[u]], mirrors every stored value into the segment tree at
// position pos[u] + j, and updates ans with every path whose topmost vertex
// is u. The caller must have set pos[u] / f[u]; w, len and sl must come from
// dfs() with the same X.
void solve(int u, int fa, db X) {
// A 0-edge path has real weight 0; store it relative to u's offset.
f[u][0] = 0 - sl[u];
add(1, 1, n, pos[u], f[u][0]);
if(w[u]) {
// The heavy child reuses u's storage shifted by one slot, so f[w][j] is
// f[u][j + 1]. dfs() sets sl[u] = sl[w] + (c - X), so the same stored number
// is correct for both nodes and nothing needs to be copied.
f[w[u]] = f[u] + 1;
pos[w[u]] = pos[u] + 1;
solve(w[u], u, X);
}
for(int i = p[u]; i + 1; i = e[i].nxt) {
int v = e[i].v;
if(v == fa || v == w[u]) continue;
// Shifted weight of edge (u, v).
db c = e[i].c - X;
// Light child: give it a fresh block of len[v] + 1 slots.
pos[v] = id;
f[v] = B + id, id += len[v] + 1;
solve(v, u, X);
// Pass 1: combine every depth j of v (j + 1 edges from u) with the depths of u
// merged so far. The u-side depth must lie in [L - (j + 1), R - (j + 1)],
// clipped to u's valid range [0, len[u]].
for(int j = 0; j <= len[v]; j ++) {
int l = pos[u] + max(0, L - (j + 1));
int r = pos[u] + min(len[u], R - (j + 1));
// Empty range: no u-side depth can make the total edge count fit [L, R].
if(l <= r) {
ans = max(ans, (query(1, 1, n, l, r) + sl[u]) + (f[v][j] + sl[v] + c));
}
}
// Pass 2: only now merge v into u, so pass 1 never pairs v with itself.
// Values are compared in real terms and stored back relative to sl[u].
for(int j = 0; j <= len[v]; j ++) {
if(f[v][j] + sl[v] + c > f[u][j + 1] + sl[u]) {
f[u][j + 1] = (f[v][j] + sl[v]) + c - sl[u];
add(1, 1, n, pos[u] + j + 1, f[u][j + 1]);
}
}
}
// Paths that start at u and go straight down with L..R edges.
int l = pos[u] + L, r = pos[u] + min(R, len[u]);
if(l <= r) ans = max(ans, query(1, 1, n, l, r) + sl[u]);
}
// Feasibility test for the binary search: is there a path with L..R edges whose
// sum of (c - X) is non-negative (within eps), i.e. whose average weight is
// at least X? Returns 1 if so, 0 otherwise.
int check(db X) {
    build(1, 1, n);   // clear the segment tree
    dfs(1, 0, X);     // heavy children, chain lengths and offsets for this X
    ans = -inf;
    // The root's chain receives the first block of storage.
    id = 1;
    pos[1] = id;
    f[1] = B + id;
    id += len[1] + 1;
    solve(1, 0, X);
    return ans > -eps ? 1 : 0;
}
int main() {
init();
scanf("%d%d%d", &n, &L, &R);
for(int i = 1; i < n; i ++) {
int u, v, c;
scanf("%d%d%d", &u, &v, &c);
insert(u, v, c), insert(v, u, c);
}
check(2.5);
db l = 0, r = 1e6 + 1;
while(l + eps < r) {
db mid = (l + r) / 2.0;
if(check(mid)) l = mid;
else r = mid;
}
printf("%.3lf", l);
return 0;
}