LOJ #6733. 人造情感
先考虑如何求解 \(W(S)\)。设 \(f_u\) 为考虑子树 \(u\) 内的路径集合的 \(W\) 值,则有转移
\[f_u=\max\left\{\sum_{v\in \operatorname{ch}_u}f_v\right\}\cup\left\{w+\sum_{v\notin \operatorname{path}(x,y)\and\operatorname{fa}_v\in \operatorname{path}(x,y)}f_v\right\} \]其中 \((x,y,w)\) 在给定的路经集合 \(U\) 中,且满足条件 \(\operatorname{lca}(x,y)=u\)。直接做显然就暴毙了,可以思考转移的性质。
观察发现这个转移涉及到“路径下方所挂的点”的求和,于是我们可以想到树上差分。令 \(f'_u=f_u-\sum\limits_{v\in\operatorname{ch}_u}f_v\),在上式中用 \(f'\) 替代 \(f\),于是转移就变成了:
\[f'_u=\max\{0\}\cup\left\{w-\sum_{v\in \operatorname{path}(x,y)\and v\neq u}f'_v\right\} \]其中 \((x,y,w)\) 在给定的路经集合 \(U\) 中,且满足条件 \(\operatorname{lca}(x,y)=u\)。
于是现在只需实现单点加链求和就行了,可以用树状数组维护 \(\operatorname{dfn}\) 序,时间复杂度为 \(\mathcal O(n\log n)\)。
接下来考虑如何计算 \(f(x,y)\)。注意到 \(f(x,y)=f_\mathrm{root}-h_{\operatorname{lca}(x,y)}-\sum\limits_{u\notin\operatorname{path}(x,y)\and \operatorname{fa}_u\in\operatorname{path}(x,y)}f_{u}\),其中 \(h_u\) 为考虑子树 \(u\) 外的路径集合的 \(W\) 值。后面涉及到“路径下方所挂的点”的求和可以直接用树上差分消掉,有 \(f(x,y)=f_{\mathrm{root}}+\sum\limits_{u\in\operatorname{path}(x,y)}f'_u-f_{\operatorname{lca}(u,v)}-h_{\operatorname{lca}(u,v)}\)。前面有关 \(f,f'\) 的值是很容易就能求出来的,关键就是如何计算 \(h_{u}\)。
类似于处理 \(f\) 的方法,我们设 \(h'_u=h_u-h_{\operatorname{fa}_u}-\sum\limits_{v\in\operatorname{ch}_{\operatorname{fa}_u}\and v\neq u}f_{v}\),那么 \(h'\) 有和 \(f'\) 类似的转移式。我们考虑枚举所有经过 \(\operatorname{fa}_u\) 但不经过 \(u\) 的路径 \((x,y,w)\),设 \(z=\operatorname{lca}(x,y)\)。对于当前枚举的 \(u\),设 \(g'_v=\begin{cases}f'_v,\qquad v不是u的祖先\\h'_v,\qquad v是u的祖先\end{cases}\),容易发现路径 \((x,y,w)\) 对 \(h'_u\) 的贡献为 \(w-\sum\limits_{v\in \operatorname{path}(x,y)\and v\neq z}g'_v\)。于是我们就得到了一个朴素的做法:先枚举结点 \(u\),然后枚举路径集合 \(U\) 中经过 \(u\) 的路径 \((x,y,w)\),在更新 \(g'\) 的同时计算贡献。该做法的时间复杂度为 \(\mathcal O(n^2)\)。
考虑优化。注意到 \(U\) 中的路径 \((x,y,w)\) 只会对其下方挂着的所有点造成贡献;更进一步,其对同一个结点下挂着的所有儿子的贡献是一样的。设结点 \(u\) 在路径 \((x,y)\) 上,结点 \(v\) 为结点 \(u\) 的一个儿子,且满足 \(v\) 不在路径 \((x,y)\) 上。设 \(z=\operatorname{lca}(x,y)\),那么该路径对 \(v\) 的贡献为
\[\begin{cases}w-(F_x-F_z)-(F_y-F_z),\qquad\qquad\qquad\qquad\! &(u=z)\\w-(F_x-F_u)-(F_y-F_z)-(H_u-H_z),&(u\in\operatorname{path}(x,z)\and u\neq z)\\w-(F_x-F_z)-(F_y-F_u)-(H_u-H_z),&(u\in\operatorname{path}(y,z)\and u\neq z)\end{cases} \]其中 \(F,H\) 分别为 \(f',h'\) 在树上的祖先和,即有 \(F_x=\sum\limits_{y\in\operatorname{path}(\mathrm{root},x)}f'_y\)。上式中的第二类和第三类可以合在一起讨论,因此只需要讨论 \(u=z\) 和 \(u\neq z\) 的情况。
- \(u=z\) 此时路径 \((x,y,w)\) 对 \(z\) 的子结点 \(v\) 有贡献当且仅当 \(v\) 不在路径 \((x,y)\) 上。将 \(z\) 的所有子结点看成一个序列,那么该路径在序列上的影响可分为 \(1/2/3\) 个区间。子结点序列上需要实现的操作是区间取 \(\max\) 单点查询。因此只需要使用线段树维护 \(z\) 的子结点序列即可。
- \(u\neq z\) 不妨设 \(u\in\operatorname{path}(x,z)\)。此时路径 \((x,y,w)\) 对 \(u\) 的子结点 \(v\) 有贡献当且仅当 \(v\) 不在路径 \((x,y)\) 上。考虑枚举结点 \(u\) 时计算 \(h'_v\),那么对 \(h'_v\) 有贡献的路径 \((x,y,w)\) 一定存在端点 \(x\) 在 \(u\) 子树中而不在 \(v\) 子树中。我们可以在结点 \(x,y\) 上存下路径 \((x,y,w)\) 的贡献,查询时转到 \(\operatorname{dfn}\) 序上就是一个区间查询。因此我们可以用线段树维护 \(\operatorname{dfn}\) 序。
时间复杂度为 \(\mathcal O((n+m)\log n)\)。
参考代码
#include <bits/stdc++.h>
using namespace std;
template<typename _Tp> _Tp &min_eq(_Tp &x, const _Tp &y) { return x = min(x, y); }
template<typename _Tp> _Tp &max_eq(_Tp &x, const _Tp &y) { return x = max(x, y); }
static constexpr int mod = 998244353;
static constexpr int Maxn = 3e5 + 5;
static constexpr int64_t inf = 0x3f3f3f3f3f3f3f3f;
int n, m;
int64_t ans;
vector<int> g[Maxn];
namespace hld {
int par[Maxn], sz[Maxn], son[Maxn], dep[Maxn];
int top[Maxn], dfn[Maxn], idfn[Maxn], ed[Maxn], dn;
void predfs1(int u, int fa, int depth) {
par[u] = fa, dep[u] = depth; sz[u] = 1, son[u] = 0;
for (const int &v: g[u]) if (v != par[u]) {
predfs1(v, u, depth + 1), sz[u] += sz[v];
if (son[u] == 0 || sz[v] > sz[son[u]]) son[u] = v;
}
} // hld::predfs1
void predfs2(int u, int topv) {
top[u] = topv, idfn[dfn[u] = ++dn] = u;
if (son[u] != 0) predfs2(son[u], topv);
for (const int &v: g[u]) if (v != par[u])
if (v != son[u]) predfs2(v, v);
} // hld::predfs2
inline int get_lca(int u, int v) {
for (; top[u] != top[v]; v = par[top[v]])
if (dep[top[u]] > dep[top[v]]) swap(u, v);
return dep[u] < dep[v] ? u : v;
} // hld::get_lca
inline int get_anc(int u, int k) {
if (k < 0 || k >= dep[u]) return 0;
for (; dep[u] - dep[par[top[u]]] <= k; u = par[top[u]])
k -= (dep[u] - dep[par[top[u]]]);
return idfn[dfn[u] - k];
} // hld::get_anc
inline void initialize(int root) {
dn = 0, predfs1(root, 0, 1), predfs2(root, root);
for (int i = 1; i <= n; ++i) ed[i] = dfn[i] + sz[i] - 1;
} // hld::initialize
} // namespace hld
using namespace hld;
namespace fen {
int64_t b[Maxn];
inline void clr(void) { memset(b, 0, sizeof(b)); }
inline void upd(int x, int64_t v) { for (; x <= n; x += x & -x) b[x] += v; }
inline int64_t ask(int x) { int64_t r = 0; for (; x; x -= x & -x) r += b[x]; return r; }
} // namespace fen
struct path { int x, y; int64_t w; } pa[Maxn];
vector<path> plca[Maxn];
int64_t f[Maxn], F[Maxn], f1[Maxn];
void dfs1(int u, int fa) {
for (const int &v: g[u]) if (v != fa) dfs1(v, u);
for (const auto &[x, y, w]: plca[u])
max_eq(f[u], w - fen::ask(dfn[x]) - fen::ask(dfn[y]));
fen::upd(dfn[u], f[u]), fen::upd(ed[u] + 1, -f[u]);
} // dfs1
void dfs11(int u, int fa) {
F[u] = f[u] + F[fa], f1[u] = f[u];
for (const int &v: g[u]) if (v != fa)
dfs11(v, u), f1[u] += f1[v];
} // dfs11
int64_t h[Maxn], H[Maxn], h1[Maxn];
namespace sgt1 {
int64_t tr[Maxn * 4];
void update(int p, int l, int r, int x, int64_t v) {
max_eq(tr[p], v);
if (l == r) return ; int mid = (l + r) / 2;
if (x <= mid) update(p * 2 + 0, l, mid, x, v);
else update(p * 2 + 1, mid + 1, r, x, v);
} // sgt1::update
int64_t query(int p, int l, int r, int L, int R) {
if (L > r || l > R) return -inf;
if (L <= l && r <= R) return tr[p];
int mid = (l + r) / 2;
return max(query(p * 2 + 0, l, mid, L, R), query(p * 2 + 1, mid + 1, r, L, R));
} // sgt1::query
inline void upd(int x, int64_t v) { return update(1, 1, n, x, v); }
inline int64_t ask(int l, int r) { return query(1, 1, n, l, r); }
} // namespace sgt1
namespace sgt2 {
int N; int64_t tr[Maxn * 4];
void build(int n) { N = n, memset(tr, -63, (n + 1) * 4 * sizeof(*tr)); }
void update(int p, int l, int r, int L, int R, int64_t v) {
if (L > r || l > R) return ;
if (L <= l && r <= R) return max_eq(tr[p], v), void();
int mid = (l + r) / 2;
update(p * 2 + 0, l, mid, L, R, v);
update(p * 2 + 1, mid + 1, r, L, R, v);
} // sgt2::update
int64_t query(int p, int l, int r, int x) {
if (l == r) return tr[p];
int mid = (l + r) / 2; int64_t t = tr[p];
if (x <= mid) max_eq(t, query(p * 2 + 0, l, mid, x));
else max_eq(t, query(p * 2 + 1, mid + 1, r, x));
return t;
} // sgt2::query
inline void upd(int l, int r, int64_t v) { return update(1, 1, N, l, r, v); }
inline int64_t ask(int x) { return query(1, 1, N, x); }
} // namespace sgt2
void dfs2(int u, int fa) {
H[u] = H[fa] + h[u];
static int label[Maxn]; int N = 0;
for (const int &v: g[u]) if (v != fa) label[v] = ++N;
if (N != 0) {
for (const int &v: g[u]) if (v != fa)
max_eq(h[v], max(sgt1::ask(dfn[u], dfn[v] - 1), sgt1::ask(ed[v] + 1, ed[u])) + F[u] - H[u]);
sgt2::build(N);
for (auto [x, y, w]: plca[u]) {
if (dep[x] > dep[y]) swap(x, y);
int xk = get_anc(x, dep[x] - dep[u] - 1);
int yk = get_anc(y, dep[y] - dep[u] - 1);
if (xk == 0 && yk == 0) {
sgt2::upd(1, N, w);
} else if (xk == 0) {
int64_t v = w - F[y] + F[u];
if (1 < label[yk]) sgt2::upd(1, label[yk] - 1, v);
if (label[yk] < N) sgt2::upd(label[yk] + 1, N, v);
} else {
int64_t v = w - F[x] + F[u] - F[y] + F[u];
if (label[xk] > label[yk]) swap(x, y), swap(xk, yk);
if (1 < label[xk]) sgt2::upd(1, label[xk] - 1, v);
if (label[yk] < N) sgt2::upd(label[yk] + 1, N, v);
if (label[xk] + 1 <= label[yk] - 1) sgt2::upd(label[xk] + 1, label[yk] - 1, v);
}
}
for (const int &v: g[u]) if (v != fa)
max_eq(h[v], sgt2::ask(label[v]));
}
for (auto [x, y, w]: plca[u]) {
if (dep[x] > dep[y]) swap(x, y);
int xk = get_anc(x, dep[x] - dep[u] - 1);
int yk = get_anc(y, dep[y] - dep[u] - 1);
if (xk == 0 && yk == 0) {
} else if (xk == 0) {
int64_t v = w - F[y] + H[u];
sgt1::upd(dfn[y], v);
} else {
int64_t v = w - F[x] - F[y] + H[u] + F[u];
sgt1::upd(dfn[x], v);
sgt1::upd(dfn[y], v);
}
}
for (const int &v: g[u]) if (v != fa) dfs2(v, u);
} // dfs2
void dfs3(int u, int fa) {
for (const int &v: g[u]) if (v != fa)
h1[v] = h1[u] + f1[u] - f1[v] - f[u] + h[v], dfs3(v, u);
} // dfs3
int64_t sf[Maxn];
void dfs4(int u, int fa) {
for (const int &v: g[u]) if (v != fa) dfs4(v, u), (sf[u] += sf[v]) %= mod;
int64_t z = (int64_t)sz[u] * sz[u];
for (const int &v: g[u]) if (v != fa) z -= (int64_t)sz[v] * sz[v];
((ans -= (z % mod) * ((f1[u] + h1[u]) % mod) % mod) += mod) %= mod;
for (const int &v: g[u]) if (v != fa) (ans += 2 * sf[v] * (sz[u] - sz[v]) % mod) %= mod;
(ans += (f[u] % mod) * (z % mod) % mod) %= mod;
(sf[u] += sz[u] * f[u] % mod) %= mod;
} // dfs4
int main(void) {
scanf("%d%d", &n, &m);
for (int i = 2; i <= n; ++i) {
int u, v;
scanf("%d%d", &u, &v);
g[u].push_back(v);
g[v].push_back(u);
}
hld::initialize(1);
for (int i = 1; i <= m; ++i) {
scanf("%d%d%lld", &pa[i].x, &pa[i].y, &pa[i].w);
int z = get_lca(pa[i].x, pa[i].y);
plca[z].push_back(pa[i]);
}
fen::clr(), dfs1(1, 0), dfs11(1, 0);
memset(sgt1::tr, -63, sizeof(sgt1::tr));
dfs2(1, 0), dfs3(1, 0);
ans = 0, dfs4(1, 0);
(ans += (int64_t)n * n % mod * (f1[1] % mod) % mod) %= mod;
printf("%lld\n", ans);
exit(EXIT_SUCCESS);
} // main