LOJ #3166. 「CEOI2019」魔法树

LOJ #3166. 「CEOI2019」魔法树

首先可以列出一个 \(\text{dp}\) 状态: 设 \(f_{u,t}\) 表示恰好在 \(t\) 时刻剪断节点 \(u\) 与其父亲的边可获得的最大收益.

这显然是一个树上背包. 先不考虑节点 \(u\) 自身的贡献, 在合并两个儿子 \(u,v\) 时的状态转移方程应为:

\[f'_{u,t}=\max\{f_{u,t}+\max_{1\le i\le t}{f_{v,i}},f_{v,t}+\max_{1\le i\le t}{f_{u,i}}\} \]

然后再考虑节点 \(u\) 自身的贡献, \(f'_{u,d_u}=\max\limits_{1\le i\le d_u}\{f_{u,i}\}+w_u\).

但这样的 \(\text{dp}\) 的时间空间复杂度均为 \(\mathrm O(n^2)\). 考虑优化.

  1. 线段树合并

发现 \(\text{dp}\) 背包部分的转移是一个前缀 \(\max\), 于是可以使用线段树合并进行优化.

  • \(\textbf{trick}\) 线段树合并优化 \(\text{dp}\) 前/后缀转移.

朴素的线段树合并是相同的位置进行合并, 在这个过程中是从左到右进行合并的. 于是在合并的同时, 可以记录两颗线段树目前合并位置的前缀值, 这样就完成了前缀转移. 同理, 线段树合并也可以优化后缀转移.

时间复杂度: \(\mathrm O(n\log n)\), 空间复杂度: \(\mathrm {O(n\log n)}\).

参考代码
#include "bits/stdc++.h"
using namespace std;
static constexpr int Maxn = 1e5 + 5;
int n, m, k;
vector<int> g[Maxn];
int d[Maxn];
int64_t w[Maxn];
struct treedot {
  int ls, rs;
  int64_t val;
  int64_t lz;
} tr[Maxn << 5];
int tot, root[Maxn];
void pushup(int p, int l, int r) {
  tr[p].val = max(tr[tr[p].ls].val, tr[tr[p].rs].val);
} // pushup
void apply(int p, int l, int r, int64_t v) {
  if (!p) return ;
  tr[p].val += v, tr[p].lz += v;
} // apply
void pushdown(int p, int l, int r) {
  if (!p) return ;
  int mid = (l + r) >> 1;
  if (tr[p].lz != 0) {
    apply(tr[p].ls, l, mid, tr[p].lz);
    apply(tr[p].rs, mid + 1, r, tr[p].lz);
    tr[p].lz = 0;
  }
} // pushdown
void modify(int &p, int l, int r, int x, int64_t v) {
  if (!p) p = ++tot;
  if (l == r) {
    tr[p].val = max(tr[p].val, v);
  } else {
    int mid = (l + r) >> 1;
    pushdown(p, l, r);
    if (x <= mid) modify(tr[p].ls, l, mid, x, v);
    else modify(tr[p].rs, mid + 1, r, x, v);
    pushup(p, l, r);
  }
} // modify
int64_t query(int p, int l, int r, int L, int R) {
  if (!p || L > r || l > R) return 0;
  if (L <= l && r <= R) return tr[p].val;
  int mid = (l + r) >> 1;
  pushdown(p, l, r);
  return max(query(tr[p].ls, l, mid, L, R), query(tr[p].rs, mid + 1, r, L, R));
} // query
int join(int u, int v, int l, int r, int64_t pu, int64_t pv) {
  if (!u && !v) return 0;
  if (!u) { apply(v, l, r, pv); return v; }
  if (!v) { apply(u, l, r, pu); return u; }
  if (l == r) {
    pu = max(pu, tr[v].val);
    pv = max(pv, tr[u].val);
    tr[u].val = max(tr[u].val + pu, tr[v].val + pv);
    return u;
  }
  int mid = (l + r) >> 1;
  pushdown(u, l, r); pushdown(v, l, r);
  int64_t lu_val = tr[tr[u].ls].val, lv_val = tr[tr[v].ls].val;
  tr[u].ls = join(tr[u].ls, tr[v].ls, l, mid, pu, pv);
  tr[u].rs = join(tr[u].rs, tr[v].rs, mid + 1, r, max(pu, lv_val), max(pv, lu_val));
  pushup(u, l, r);
  return u;
} // join
void dfs(int u, int fa) {
  for (const int &v: g[u]) if (v != fa) {
    dfs(v, u);
    root[u] = join(root[u], root[v], 1, k, 0LL, 0LL);
  }
  if (d[u] != 0) {
    int64_t W = query(root[u], 1, k, 1, d[u]);
    modify(root[u], 1, k, d[u], W + w[u]);
  }
} // dfs
int main(void) {
  scanf("%d%d%d", &n, &m, &k);
  for (int i = 2, pi; i <= n; ++i) {
    scanf("%d", &pi);
    g[pi].push_back(i);
    g[i].push_back(pi);
  }
  for (int i = 1, v; i <= m; ++i) {
    scanf("%d", &v);
    scanf("%d%lld", &d[v], &w[v]);
  }
  dfs(1, 0);
  printf("%lld\n", tr[root[1]].val);
  exit(EXIT_SUCCESS);
} // main
  1. 树上启发式合并

容易发现, 节点 \(u\) 有贡献的 \(\text{dp}\) 值最多就 \(\text{sz}_u\) 个, 其中 \(\text{sz}_u\) 表示子树 \(u\) 的大小. 于是可以想到树上启发式合并.

注意到直接合并 \(f\) 的值复杂度会爆炸. 考虑合并 \(f\) 的前缀值.

设 \(g_{u,t}=\max\limits_{1\le i\le t}f_{u,i}\) 易知 \(g_u\) 是单调不降的. 将上述的转移方程改写, 有

\[g_{u,t}=\max\left\{[d_u\le t](g_{u,d_u}+w_u),\,\sum\limits_{v\in \text{son}_u}g_{v,t}\right\} \]

发现 \(g_{u}\) 中不同的值最多有 \(\text{sz}_u\) 个, 于是可以用 map 维护 \(g_u\) 不同值域的最小下标和值.

前面那一部分可以直接暴力更新, 但后面那一部分不太好直接启发式合并, 于是改用 map 维护 \(g_u\) 差分数组. 这样就可以直接启发式合并了.

时间复杂度: \(\mathrm O(n\log^2n)\), 空间复杂度: \(O(n)\).

参考代码
#include "bits/stdc++.h"
using namespace std;
static constexpr int Maxn = 1e5 + 5;
int n, m, k;
vector<int> g[Maxn];
int d[Maxn];
int64_t w[Maxn];
int sz[Maxn], son[Maxn];
map<int, int64_t> s[Maxn];
void join(map<int, int64_t> &x, map<int, int64_t> &y) {
  for (const auto &[t, v]: y) x[t] += v;
  y.clear();
} // join
void dfs(int u, int fa) {
  sz[u] = 1; son[u] = 0;
  for (const int &v: g[u]) if (v != fa) {
    dfs(v, u), sz[u] += sz[v];
    if (!son[u] || sz[son[u]] < sz[v]) son[u] = v;
  }
  if (son[u]) join(s[son[u]], s[u]), s[u].swap(s[son[u]]);
  for (const int &v: g[u]) if (v != fa) {
    if (v != son[u]) {
      join(s[u], s[v]);
    }
  }
  if (d[u] != 0) {
    s[u][d[u]] += w[u];
    int W = w[u];
    for (auto it = next(s[u].find(d[u])); it != s[u].end(); ) {
      if (it->second > W) {
        it->second -= W;
        break;
      }
      W -= it->second;
      s[u].erase(it++);
    }
  }
} // dfs
int main(void) {
  scanf("%d%d%d", &n, &m, &k);
  for (int i = 2, pi; i <= n; ++i) {
    scanf("%d", &pi);
    g[pi].push_back(i);
    g[i].push_back(pi);
  }
  for (int i = 1, v; i <= m; ++i) {
    scanf("%d", &v);
    scanf("%d%lld", &d[v], &w[v]);
  }
  dfs(1, 0);
  int64_t ans = 0;
  for (const auto &[t, v]: s[1]) ans += v;
  printf("%lld\n", ans);
  exit(EXIT_SUCCESS);
} // main
上一篇:国外最有价值的教育科技公司 Byju‘s


下一篇:【算法提高——第四讲】高级数据结构