LOJ #3166. 「CEOI2019」魔法树
首先可以列出一个 \(\text{dp}\) 状态: 设 \(f_{u,t}\) 表示恰好在 \(t\) 时刻剪断节点 \(u\) 与其父亲的边可获得的最大收益.
这显然是一个树上背包. 先不考虑节点 \(u\) 自身的贡献, 在合并两个儿子 \(u,v\) 时的状态转移方程应为:
\[f'_{u,t}=\max\{f_{u,t}+\max_{1\le i\le t}{f_{v,i}},f_{v,t}+\max_{1\le i\le t}{f_{u,i}}\} \]然后再考虑节点 \(u\) 自身的贡献, \(f'_{u,d_u}=\max\limits_{1\le i\le d_u}\{f_{u,i}\}+w_u\).
但这样的 \(\text{dp}\) 的时间空间复杂度均为 \(\mathrm O(n^2)\). 考虑优化.
- 线段树合并
发现 \(\text{dp}\) 背包部分的转移是一个前缀 \(\max\), 于是可以使用线段树合并进行优化.
- \(\textbf{trick}\) 线段树合并优化 \(\text{dp}\) 前/后缀转移.
朴素的线段树合并是相同的位置进行合并, 在这个过程中是从左到右进行合并的. 于是在合并的同时, 可以记录两颗线段树目前合并位置的前缀值, 这样就完成了前缀转移. 同理, 线段树合并也可以优化后缀转移.
时间复杂度: \(\mathrm O(n\log n)\), 空间复杂度: \(\mathrm {O(n\log n)}\).
参考代码
#include "bits/stdc++.h"
using namespace std;
static constexpr int Maxn = 1e5 + 5;
int n, m, k;
vector<int> g[Maxn];
int d[Maxn];
int64_t w[Maxn];
struct treedot {
int ls, rs;
int64_t val;
int64_t lz;
} tr[Maxn << 5];
int tot, root[Maxn];
void pushup(int p, int l, int r) {
tr[p].val = max(tr[tr[p].ls].val, tr[tr[p].rs].val);
} // pushup
void apply(int p, int l, int r, int64_t v) {
if (!p) return ;
tr[p].val += v, tr[p].lz += v;
} // apply
void pushdown(int p, int l, int r) {
if (!p) return ;
int mid = (l + r) >> 1;
if (tr[p].lz != 0) {
apply(tr[p].ls, l, mid, tr[p].lz);
apply(tr[p].rs, mid + 1, r, tr[p].lz);
tr[p].lz = 0;
}
} // pushdown
void modify(int &p, int l, int r, int x, int64_t v) {
if (!p) p = ++tot;
if (l == r) {
tr[p].val = max(tr[p].val, v);
} else {
int mid = (l + r) >> 1;
pushdown(p, l, r);
if (x <= mid) modify(tr[p].ls, l, mid, x, v);
else modify(tr[p].rs, mid + 1, r, x, v);
pushup(p, l, r);
}
} // modify
int64_t query(int p, int l, int r, int L, int R) {
if (!p || L > r || l > R) return 0;
if (L <= l && r <= R) return tr[p].val;
int mid = (l + r) >> 1;
pushdown(p, l, r);
return max(query(tr[p].ls, l, mid, L, R), query(tr[p].rs, mid + 1, r, L, R));
} // query
int join(int u, int v, int l, int r, int64_t pu, int64_t pv) {
if (!u && !v) return 0;
if (!u) { apply(v, l, r, pv); return v; }
if (!v) { apply(u, l, r, pu); return u; }
if (l == r) {
pu = max(pu, tr[v].val);
pv = max(pv, tr[u].val);
tr[u].val = max(tr[u].val + pu, tr[v].val + pv);
return u;
}
int mid = (l + r) >> 1;
pushdown(u, l, r); pushdown(v, l, r);
int64_t lu_val = tr[tr[u].ls].val, lv_val = tr[tr[v].ls].val;
tr[u].ls = join(tr[u].ls, tr[v].ls, l, mid, pu, pv);
tr[u].rs = join(tr[u].rs, tr[v].rs, mid + 1, r, max(pu, lv_val), max(pv, lu_val));
pushup(u, l, r);
return u;
} // join
void dfs(int u, int fa) {
for (const int &v: g[u]) if (v != fa) {
dfs(v, u);
root[u] = join(root[u], root[v], 1, k, 0LL, 0LL);
}
if (d[u] != 0) {
int64_t W = query(root[u], 1, k, 1, d[u]);
modify(root[u], 1, k, d[u], W + w[u]);
}
} // dfs
int main(void) {
scanf("%d%d%d", &n, &m, &k);
for (int i = 2, pi; i <= n; ++i) {
scanf("%d", &pi);
g[pi].push_back(i);
g[i].push_back(pi);
}
for (int i = 1, v; i <= m; ++i) {
scanf("%d", &v);
scanf("%d%lld", &d[v], &w[v]);
}
dfs(1, 0);
printf("%lld\n", tr[root[1]].val);
exit(EXIT_SUCCESS);
} // main
- 树上启发式合并
容易发现, 节点 \(u\) 有贡献的 \(\text{dp}\) 值最多就 \(\text{sz}_u\) 个, 其中 \(\text{sz}_u\) 表示子树 \(u\) 的大小. 于是可以想到树上启发式合并.
注意到直接合并 \(f\) 的值复杂度会爆炸. 考虑合并 \(f\) 的前缀值.
设 \(g_{u,t}=\max\limits_{1\le i\le t}f_{u,i}\) 易知 \(g_u\) 是单调不降的. 将上述的转移方程改写, 有
\[g_{u,t}=\max\left\{[d_u\le t](g_{u,d_u}+w_u),\,\sum\limits_{v\in \text{son}_u}g_{v,t}\right\} \]发现 \(g_{u}\) 中不同的值最多有 \(\text{sz}_u\) 个, 于是可以用 map
维护 \(g_u\) 不同值域的最小下标和值.
前面那一部分可以直接暴力更新, 但后面那一部分不太好直接启发式合并, 于是改用 map
维护 \(g_u\) 差分数组. 这样就可以直接启发式合并了.
时间复杂度: \(\mathrm O(n\log^2n)\), 空间复杂度: \(O(n)\).
参考代码
#include "bits/stdc++.h"
using namespace std;
static constexpr int Maxn = 1e5 + 5;
int n, m, k;
vector<int> g[Maxn];
int d[Maxn];
int64_t w[Maxn];
int sz[Maxn], son[Maxn];
map<int, int64_t> s[Maxn];
void join(map<int, int64_t> &x, map<int, int64_t> &y) {
for (const auto &[t, v]: y) x[t] += v;
y.clear();
} // join
void dfs(int u, int fa) {
sz[u] = 1; son[u] = 0;
for (const int &v: g[u]) if (v != fa) {
dfs(v, u), sz[u] += sz[v];
if (!son[u] || sz[son[u]] < sz[v]) son[u] = v;
}
if (son[u]) join(s[son[u]], s[u]), s[u].swap(s[son[u]]);
for (const int &v: g[u]) if (v != fa) {
if (v != son[u]) {
join(s[u], s[v]);
}
}
if (d[u] != 0) {
s[u][d[u]] += w[u];
int W = w[u];
for (auto it = next(s[u].find(d[u])); it != s[u].end(); ) {
if (it->second > W) {
it->second -= W;
break;
}
W -= it->second;
s[u].erase(it++);
}
}
} // dfs
int main(void) {
scanf("%d%d%d", &n, &m, &k);
for (int i = 2, pi; i <= n; ++i) {
scanf("%d", &pi);
g[pi].push_back(i);
g[i].push_back(pi);
}
for (int i = 1, v; i <= m; ++i) {
scanf("%d", &v);
scanf("%d%lld", &d[v], &w[v]);
}
dfs(1, 0);
int64_t ans = 0;
for (const auto &[t, v]: s[1]) ans += v;
printf("%lld\n", ans);
exit(EXIT_SUCCESS);
} // main