题意:
- 给你一棵 \(n\) 个顶点的树和 \(m\) 条带权值的附加边
- 你要选择一些附加边加入原树中使其成为一个仙人掌(每个点最多属于 \(1\) 个简单环)
- 求你选择的附加边权值之和的最大值。
- \(n \in [3,2 \times 10^5]\),\(m \in [0,2 \times 10^5]\)
很容易发现,如果你选择一条端点为 \(u,v\) 的附加边,那么 \(u\) 到 \(v\) 路径上所有点都会被包含进一个简单环。
那么题目变为:有 \(m\) 条路径,你要选择其中某些路径使得每个点最多被一条路径覆盖。问选择的路径权值之和的最大值。
对于这一类问题,我们可以想到树形 \(dp\)。记 \(dp_u\) 为:选择一些路径,路径上所有点都在 \(u\) 的子树内,并且路径两两不相交,这些路径权值之和的最大值。
转移分两种情况:
- \(u\) 没有被覆盖,那么 \(dp_u=\sum\limits_{fa_v=u}dp_v\)
- \(u\) 被覆盖。那么显然它只能被 \(LCA=u\) 的路径覆盖(因为根据之前的定义,路径上所有节点都必须在 \(u\) 的子树内)。这种情况下就是 \(u\) 的子树挖掉这条路径后剩余部分的答案+这条路径的权值。
是不是听起来有些抽象?不妨画个图理解看:
假设有如下图的树,你选择了由橙色边组成的路径:
我们把这条路径上的点全部去掉看看剩余部分是个什么东西。
凭直觉,这部分的贡献为 \(dp[4]+dp[5]+dp[6]+dp[7]+dp[10]+dp[13]+dp[16]+dp[17]+dp[19]+dp[21]+dp[24]\)
发现了什么了吗?
你可能会说,项数这么多,我啥也没发现。
那我们就把删掉的点补全后看看呗。
通过观察,我们发现:
- \(4,5\) 是 \(3\) 的儿子
- \(6,7\) 是 \(2\) 除了 \(3\) 之外的儿子
- \(10,13\) 是 \(1\) 除了 \(2,15\) 之外的儿子
- \(16,17\) 是 \(15\) 除了 \(18\) 之外的儿子
- \(19\) 是 \(18\) 除了 \(20\) 之外的儿子
- \(21,24\) 是 \(20\) 的儿子。
我们记 \(p_1 \to p_2 \to \dots \to p_k\) 为我们选择的路径,其中 \(p_j=i\),\(p_0=p_{k+1}=0\),那么贡献为:
\[\sum\limits_{i=1}^j\sum\limits_{fa_v=p_i}dp_v \times [v \neq p_{i-1}]+\sum\limits_{fa_v=u}dp_v \times [v \neq p_{i-1} \and v \neq p_{i+1}]+\sum\limits_{i=j+1}^k\sum\limits_{fa_v=p_i}dp_v \times [v \neq p_{i+1}]+w
\]
\]
暴力显然是 \(n^2\) 的,考虑对其进行优化。
记 \(g_x=\sum\limits_{fa_y=x}dp_y-dp_x\),那么上式子等价于
\[\sum\limits_{i=1}^{j-1}g_{p_i}+\sum\limits_{fa_v=u}dp_v+\sum\limits_{i=j+1}^kg_{p_i}+w
\]
\]
用树链剖分随便维护一下就可以了
/*
Contest: -
Problem: Codeforces 856D
Author: tzc_wk
Time: 2020.8.20
*/
#include <bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define pb push_back
#define fz(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
#define foreach(it,v) for(__typeof(v.begin()) it=v.begin();it!=v.end();it++)
#define all(a) a.begin(),a.end()
#define fill0(a) memset(a,0,sizeof(a))
#define fill1(a) memset(a,-1,sizeof(a))
#define fillbig(a) memset(a,0x3f,sizeof(a))
#define y1 y1010101010101
#define y0 y0101010101010
typedef pair<int,int> pii;
typedef long long ll;
inline int read(){
int x=0,neg=1;char c=getchar();
while(!isdigit(c)){
if(c=='-') neg=-1;
c=getchar();
}
while(isdigit(c)) x=x*10+c-'0',c=getchar();
return x*neg;
}
int n=read(),m=read();
int head[200005],to[400005],nxt[400005],ecnt=0;
inline void adde(int u,int v){
to[++ecnt]=v;nxt[ecnt]=head[u];head[u]=ecnt;
}
int fa[200005],top[200005],dep[200005],siz[200005],wson[200005],dfn[200005],nd[200005],tim=0;
inline void dfs1(int x){
siz[x]=1;
for(int i=head[x];i;i=nxt[i]){
int y=to[i];dep[y]=dep[x]+1;dfs1(y);siz[x]+=siz[y];
if(siz[y]>siz[wson[x]]) wson[x]=y;
}
}
inline void dfs2(int x,int tp){
top[x]=tp;dfn[x]=++tim;nd[tim]=x;
if(wson[x]) dfs2(wson[x],tp);
for(int i=head[x];i;i=nxt[i]){
int y=to[i];if(y!=wson[x]) dfs2(y,y);
}
}
inline int getlca(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x=fa[top[x]];
}
if(dep[x]<dep[y]) swap(x,y);
return y;
}
struct path{
int u,v,w;
path(){/*nothing*/}
path(int _u,int _v,int _w){u=_u;v=_v;w=_w;}
}; vector<path> t[200005];
struct node{
int l,r,val;
} s[200005<<2];
inline void build(int k,int l,int r){
s[k].l=l;s[k].r=r;
if(l==r) return;
int mid=(l+r)>>1;
build(k<<1,l,mid);
build(k<<1|1,mid+1,r);
}
inline void modify(int k,int ind,int val){
if(s[k].l==s[k].r){s[k].val+=val;return;}
int mid=(s[k].l+s[k].r)>>1;
if(ind<=mid) modify(k<<1,ind,val);
else modify(k<<1|1,ind,val);
s[k].val=s[k<<1].val+s[k<<1|1].val;
}
inline int query(int k,int l,int r){
if(l<=s[k].l&&s[k].r<=r) return s[k].val;
int mid=(s[k].l+s[k].r)>>1;
if(r<=mid) return query(k<<1,l,r);
else if(l>mid) return query(k<<1|1,l,r);
else return query(k<<1,l,mid)+query(k<<1|1,mid+1,r);
}
inline int calc(int u,int v){
int sum=0;
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]]) swap(u,v);
sum+=query(1,dfn[top[u]],dfn[u]);u=fa[top[u]];
}
if(dep[u]<dep[v]) swap(u,v);
sum+=query(1,dfn[v],dfn[u]);return sum;
}
int dp[200005];
inline void dfs3(int x){
int sum=0;
for(int i=head[x];i;i=nxt[i]){
dfs3(to[i]);sum+=dp[to[i]];
}
dp[x]=sum;
foreach(it,t[x]) dp[x]=max(dp[x],calc(it->u,it->v)+it->w+sum);
modify(1,dfn[x],sum-dp[x]);
}
int main(){
for(int i=2;i<=n;i++){fa[i]=read();adde(fa[i],i);}
dfs1(1);dfs2(1,1);build(1,1,n);
for(int i=1;i<=m;i++){
int u=read(),v=read(),w=read();
t[getlca(u,v)].pb(path(u,v,w));
}
dfs3(1);printf("%d\n",dp[1]);
return 0;
}