做法一:
线段树合并.
令 $\mathrm{dp}[x][i]$ 表示: 以 $\mathrm{x}$ 节点为根的子树中的边全部被覆盖, 且所选链中最高的一条向上延伸到深度为 $\mathrm{i}$ 的祖先时的最小代价.
考虑 $\mathrm{x},\mathrm{y}$ 两个子树如何合并:
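有 $\mathrm{dp}'[x][i]=\mathrm{dp}[x][i]+\min_{j\le \mathrm{dep}[x]}\mathrm{dp}[y][j]$, $\mathrm{dp}'[y][i]$ 同理 (即 $\mathrm{dp}'[y][i]=\mathrm{dp}[y][i]+\min_{j\le \mathrm{dep}[x]}\mathrm{dp}[x][j]$), 合并后 $\mathrm{dp}[x][i]=\min(\mathrm{dp}'[x][i],\mathrm{dp}'[y][i])$. 这里只取 $j\le \mathrm{dep}[x]$ 的状态, 因为只有这些状态覆盖了边 $(x,y)$.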
这个可以用线段树来整体维护: 每次合并就是对 $\mathrm{x}$ 与 $\mathrm{y}$ 各自的整棵线段树打一个全局加法 (懒标记).
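以 $\mathrm{x}$ 为下端点、向上到祖先 $\mathrm{tp}$ 的链 (代价 $c$) 对应一次单点更新 $\mathrm{dp}[x][\mathrm{dep}[tp]]\leftarrow\min\left(\mathrm{dp}[x][\mathrm{dep}[tp]],\ c+\min_{j\le \mathrm{dep}[x]}\mathrm{dp}[x][j]\right)$, 同样可以用线段树来维护.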
两个子树的线段树用线段树合并即可, 合并时对应位置取 $\min$, 区间最小值也相应取 $\min$.
#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>

#define setIO(s) freopen(s".in", "r", stdin)

using namespace std;

typedef long long ll;

// Array bound for per-vertex tables (assumes n <= 300000 — TODO confirm limit).
const int N = 300009;
// Segment-tree node pool. NOTE(review): sized as 30 nodes per vertex; every point
// update creates at most ~20 new nodes, so this is an estimate, not a proven bound.
const int MAXNODE = N * 30;

// Any answer above in2 is reported as impossible (-1).
const ll in2 = static_cast<ll>(1e16);
// Saturating stand-in for +infinity; stored values and lazy tags are clamped to it.
const ll inf = static_cast<ll>(1e17);

int n, m, tot;
int rt[N];   // rt[x]: root of the segment tree (indexed by depth) holding dp[x][*]
int dep[N];  // depth of each vertex; the root has depth 1
vector<int> G[N];

// A chain going up to ancestor `x` with cost `c`; stored under its lower endpoint.
struct Line {
    int x, c;
    Line(int x = 0, int c = 0) : x(x), c(c) {}
};
vector<Line> q[N];  // q[v]: chains whose deeper endpoint is v

// Dynamic segment-tree node. Named SegNode (not `data`) so it cannot clash with
// std::data, which `using namespace std` makes visible in C++17.
struct SegNode {
    ll mi;       // minimum over this node's range (own lazy tag already applied)
    int ls, rs;  // child indices; 0 is the shared empty (all-inf) node
    SegNode() : mi(inf), ls(0), rs(0) {}
};
SegNode s[MAXNODE];
ll add[MAXNODE];  // pending addition not yet pushed to the children

// Lazily add v to every value in the subtree rooted at x.
// Results are clamped at inf so repeated "+inf" cannot overflow.
void mark(int x, ll v) {
    if (!x) return;
    add[x] = min(inf, add[x] + v);
    s[x].mi = min(inf, s[x].mi + v);
}

// Push x's pending addition to its existing children, then clear it.
void pushdown(int x) {
    if (add[x]) {
        mark(s[x].ls, add[x]);
        mark(s[x].rs, add[x]);
    }
    add[x] = 0;
}

// Recompute x's minimum from its children (node 0 contributes inf).
void pushup(int x) {
    s[x].mi = min(s[s[x].ls].mi, s[s[x].rs].mi);
}

// Point "chmin": value at position p becomes min(old, v). Nodes are created on demand.
void update(int &x, int l, int r, int p, ll v) {
    if (!x) {
        x = ++tot;
        s[x].mi = inf;
    }
    s[x].mi = min(s[x].mi, v);
    if (l == r) return;
    pushdown(x);
    int mid = (l + r) >> 1;
    if (p <= mid)
        update(s[x].ls, l, mid, p, v);
    else
        update(s[x].rs, mid + 1, r, p, v);
    pushup(x);
}

// Merge tree y into tree x, keeping the pointwise minimum; returns the merged root.
// Nodes of y are absorbed or discarded, so y must not be used afterwards.
int merge(int l, int r, int x, int y) {
    if (!x || !y) return x + y;
    // The minimum of a pointwise min is the min of both minima, so no pushup is needed.
    s[x].mi = min(s[x].mi, s[y].mi);
    if (l == r) return x;
    int mid = (l + r) >> 1;
    pushdown(x);
    pushdown(y);
    s[x].ls = merge(l, mid, s[x].ls, s[y].ls);
    s[x].rs = merge(mid + 1, r, s[x].rs, s[y].rs);
    return x;
}

// Minimum over positions [L, R] of the tree rooted at x (node range [l, r]); inf if empty.
ll query(int l, int r, int x, int L, int R) {
    if (!x || l > r || L > R) return inf;
    if (l >= L && r <= R) return s[x].mi;
    pushdown(x);
    int mid = (l + r) >> 1;
    if (L <= mid && R > mid)
        return min(query(l, mid, s[x].ls, L, R), query(mid + 1, r, s[x].rs, L, R));
    else if (L <= mid)
        return query(l, mid, s[x].ls, L, R);
    else
        return query(mid + 1, r, s[x].rs, L, R);
}

// Post-order DP. Afterwards rt[x] holds dp[x][*], where dp[x][i] is the minimum cost
// to cover every edge of x's subtree with the highest chosen chain reaching depth i.
// Only positions i <= dep[x] are meaningful; deeper positions may keep stale values,
// but neither x nor any ancestor ever queries them.
// NOTE(review): recursion depth equals the tree height (up to n); a path-shaped
// tree may need an enlarged stack.
void dfs(int x, int ff) {
    dep[x] = dep[ff] + 1;
    int son = 0;
    for (size_t i = 0; i < G[x].size(); ++i) {
        int y = G[x][i];
        if (y == ff) continue;
        dfs(y, x);
        if (!son) {
            // First child: x's table starts as the child's table.
            ++son;
            rt[x] = rt[y];
            continue;
        }
        ++son;
        // Merge child y. Each side's state is paired with the cheapest state of the
        // other side whose chain still reaches x (top depth <= dep[x]), so the
        // edge (x, y) stays covered.
        ll px = query(1, n, rt[x], 1, dep[x]);
        ll py = query(1, n, rt[y], 1, dep[x]);
        mark(rt[x], py);
        mark(rt[y], px);
        rt[x] = merge(1, n, rt[x], rt[y]);
    }
    if (son == 0) {
        // Leaf: nothing to cover yet, cost 0, "top" at x itself.
        update(rt[x], 1, n, dep[x], 0ll);
    }
    // All children merged: add every chain whose lower endpoint is x.
    for (size_t i = 0; i < q[x].size(); ++i) {
        int tp = q[x][i].x;
        ll v = (ll)q[x][i].c;
        ll px = query(1, n, rt[x], 1, dep[x]);
        if (v + px < inf) {
            update(rt[x], 1, n, dep[tp], v + px);
        }
    }
}

// Compute depths up front so each chain can be stored under its deeper endpoint.
void ini(int x, int ff) {
    dep[x] = dep[ff] + 1;
    for (size_t i = 0; i < G[x].size(); ++i)
        if (G[x][i] != ff) ini(G[x][i], x);
}

int main() {
    // setIO("input");
    s[0].mi = inf;  // node 0 acts as an empty (all-inf) subtree
    scanf("%d%d", &n, &m);
    for (int i = 1; i < n; ++i) {
        int x, y;
        scanf("%d%d", &x, &y);
        G[x].push_back(y);
        G[y].push_back(x);
    }
    ini(1, 0);
    for (int i = 1; i <= m; ++i) {
        int x, y, z;
        scanf("%d%d%d", &x, &y, &z);
        if (x == y) continue;  // a zero-length chain covers no edge
        // Assumes one endpoint is an ancestor of the other (input guarantee — TODO confirm).
        if (dep[y] > dep[x]) swap(x, y);
        q[x].push_back(Line(y, z));
    }
    dfs(1, 0);
    // dp[1][1]: every edge covered; tops may reach the root.
    ll ans = query(1, n, rt[1], 1, 1);
    if (ans > in2)
        printf("-1\n");
    else
        printf("%lld\n", ans);
    return 0;
}