树状数组 + dp
设$f_i$表示以$i$为根的子树中的能选取的最大和,$sum_x$表示$\sum_{f_y}$ ($y$是$x$的一个儿子),这样子我们把所有给出的链按照两点的$lca$分组,对于每一个点$x$,$sum_x$显然是一个$f_x$的一个备选答案,而当有树链的$lca$正好是$x$时,我们发现$sum_x + w + \sum_{sum_t} - \sum_{f_t}$($w$代表这条树链能产生的价值,$t$是树链上的一个点)。
那么我们只要能快速计算出这两个$\sum$就可以转移了,其实用两个树状数组维护$dfs$序即可。
具体可以参照代码。
时间复杂度$O(Tnlogn)$。
Code:
#include <cstdio>
#include <cstring>
#include <vector>
using namespace std; const int N = 1e5 + ;
const int Lg = ; int testCase, n, m, f[N], sum[N], dfsc, in[N], out[N];
int tot, head[N], fa[N][Lg], dep[N];
vector <int> vec[N]; struct Edge {
int to, nxt;
} e[N << ]; inline void add(int from, int to) {
e[++tot].to = to;
e[tot].nxt = head[from];
head[from] = tot;
} struct Item {
int u, v, lca, val;
} a[N]; inline void read(int &X) {
X = ; char ch = ; int op = ;
for(; ch > '' || ch < ''; ch = getchar())
if(ch == '-') op = -;
for(; ch >= '' && ch <= ''; ch = getchar())
X = (X << ) + (X << ) + ch - ;
X *= op;
} struct Bit {
int s[N << ]; #define lowbit(p) (p & (-p)) inline void clear() {
memset(s, , sizeof(s));
} inline void modify(int p, int val) {
for(; p <= * n; p += lowbit(p))
s[p] += val;
} inline int query(int p) {
int res = ;
for(; p > ; p -= lowbit(p))
res += s[p];
return res;
} } bit1, bit2; inline void swap(int &x, int &y) {
int t = x; x = y; y = t;
} inline void chkMax(int &x, int y) {
if(y > x) x = y;
} void dfs(int x, int fat, int depth) {
fa[x][] = fat, dep[x] = depth;
in[x] = ++dfsc;
for(int i = ; i <= ; i++)
fa[x][i] = fa[fa[x][i - ]][i - ]; for(int i = head[x]; i; i = e[i].nxt) {
int y = e[i].to;
if(y == fat) continue;
dfs(y, x, depth + );
}
out[x] = ++dfsc;
} inline int getLca(int x, int y) {
if(dep[x] < dep[y]) swap(x, y);
for(int i = ; i >= ; i--)
if(dep[fa[x][i]] >= dep[y])
x = fa[x][i];
if(x == y) return x;
for(int i = ; i >= ; i--)
if(fa[x][i] != fa[y][i])
x = fa[x][i], y = fa[y][i];
return fa[x][];
} void solve(int x, int fat) {
for(int i = head[x]; i; i = e[i].nxt) {
int y = e[i].to;
if(y == fat) continue;
solve(y, x); sum[x] += f[y];
} chkMax(f[x], sum[x]);
for(unsigned int i = ; i < vec[x].size(); i++) {
int id = vec[x][i];
int now = bit1.query(in[a[id].v]) + bit1.query(in[a[id].u]) - bit2.query(in[a[id].u]) - bit2.query(in[a[id].v]) + sum[x];
chkMax(f[x], now + a[id].val);
} bit1.modify(in[x], sum[x]), bit1.modify(out[x], -sum[x]);
bit2.modify(in[x], f[x]), bit2.modify(out[x], -f[x]);
} int main() {
for(read(testCase); testCase--; ) {
tot = ; memset(head, , sizeof(head));
read(n), read(m);
for(int x, y, i = ; i < n; i++) {
read(x), read(y);
add(x, y), add(y, x);
} dfsc = ; dfs(, , ); for(int i = ; i <= n; i++) vec[i].clear();
for(int i = ; i <= m; i++) {
read(a[i].u), read(a[i].v), read(a[i].val);
a[i].lca = getLca(a[i].u, a[i].v);
vec[a[i].lca].push_back(i);
} memset(f, , sizeof(f));
memset(sum, , sizeof(sum));
bit1.clear(), bit2.clear();
solve(, ); printf("%d\n", f[]);
} return ;
}