题意
给出一棵树,每条边有一种颜色,每种颜色都有相应的权值。一条路径的权值为:连续相同的颜色视为一段,权值为所有颜色段对应的权值和,比如路径颜色为:1 1 1 2 3 1 1 4 4,那么颜色段为:1 2 3 1 4。
求长度在 [ L , R ] [L, R] [L,R]范围内的路径的最大权值
题解
看到树上路径问题显然想到点分治,但是求出子树的各路径权值后,不能直接选最大的相加,假如两条子树路径的初始颜色一样那就得减去一次初始颜色权值。所以得开两颗线段树(线段树用来维护路径长度在 [ x , y ] [x,y] [x,y]的最大权值),一颗来维护对于当前颜色,一颗用来维护总颜色,然后当子树初始颜色变化时,将当前颜色合并到总颜色,并清空当前颜色。
这题把两颗线段树合起来写更方便,但我懒得改了
#include<iostream>
#include<sstream>
#include<string>
#include<queue>
#include<map>
#include<unordered_map>
#include<set>
#include<vector>
#include<stack>
#include<utility>
#include<list>
#include<bitset>
#include<algorithm>
#include<cstdio>
#include<cmath>
#include<cstdlib>
#include<cstring>
#include<iomanip>
#include<time.h>
#include<random>
using namespace std;
#include<ext/pb_ds/priority_queue.hpp>
#include<ext/pb_ds/tree_policy.hpp>
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/hash_policy.hpp>
using namespace __gnu_pbds;
#include<ext/rope>
using namespace __gnu_cxx;
#define PI acos(-1.0)
#define eps 1e-9
#define lowbit(a) ((a)&-(a))
#define mid ((l+r)>>1)
#define mem(x,y) memset(x,y,sizeof x)
const int mod = 1e9+7;
int qpow(int a,int b){
int ans=1;
while(b){
if(b&1)ans=(ans*a)%mod;
a=(a*a)%mod;
b>>=1;
}
return ans;
}
const int INF = 0x3f3f3f3f;
const int N = 6e6+10;
struct SGT{
int l,r;
int w;
}tot[N*4],sam[N*2];
int rt1,rt2,cnt1,cnt2;
int n,m,L,R;
struct node{
int to,w;
bool operator<(node x){
return w<x.w;
}
};
struct Node{
int to,next,w;
}e[N*2];
int head[N],k;
int alls,son[N],MX,RT,sz[N],vis[N];
int c[N];
int top,ans=-INF;
pair<int,int>pat[N];
void add(int u,int v,int w){
e[++k]={v,head[u],w}; head[u]=k;
}
void merge(int &x,int y,int l=0,int r=n){
if(!x)tot[x=++cnt1].w=-INF;
if(!y)return;
tot[x].w=max(tot[x].w,sam[y].w);
merge(tot[x].l,sam[y].l,l,mid);
merge(tot[x].r,sam[y].r,mid+1,r);
}
void ins(int &now,int ver,int l,int r,int pos,int val){
if(!now){
if(ver==2)sam[now=++cnt2].w=-INF;
else tot[now=++cnt1].w=-INF;
}
if(l==r){
if(ver==2)sam[now].w=max(sam[now].w,val);
else tot[now].w=max(tot[now].w,val);
return;
}
if(pos<=mid){
if(ver==2) ins(sam[now].l,ver,l,mid,pos,val);
else ins(tot[now].l,ver,l,mid,pos,val);
}
else{
if(ver==2) ins(sam[now].r,ver,mid+1,r,pos,val);
else ins(tot[now].r,ver,mid+1,r,pos,val);
}
if(ver==2)sam[now].w=max(sam[now].w,val);
else tot[now].w=max(tot[now].w,val);
}
int query(int now,int ver,int l,int r,int ql,int qr){
if(!now)return -INF;
if(l>qr||r<ql)return -INF;
if(ql<=l&&r<=qr)return ver==1?tot[now].w:sam[now].w;
if(ver==2) return max(query(sam[now].l,ver,l,mid,ql,qr),query(sam[now].r,ver,mid+1,r,ql,qr));
else return max(query(tot[now].l,ver,l,mid,ql,qr),query(tot[now].r,ver,mid+1,r,ql,qr));
}
void getroot(int u,int fa){
sz[u]=1,son[u]=0;
for(int i=head[u];i;i=e[i].next){
int v=e[i].to;
if(v==fa||vis[v])continue;
getroot(v,u);
sz[u]+=sz[v];
son[u]=max(son[u],sz[v]);
}
son[u]=max(son[u],alls-sz[u]);
if(son[u]<MX)RT=u,MX=son[u];
}
void getdis(int u,int fa,int val,int len,int lastc){
if(len>R)return;
pat[++top].first=len,pat[top].second=val;
for(int i=head[u];i;i=e[i].next){
int v=e[i].to,w=e[i].w;
if(v==fa||vis[v])continue;
getdis(v,u,w==lastc?val:val+c[w],len+1,w);
}
}
void cal(int u){
rt1=rt2=0;
ins(rt1,1,0,n,0,0);
vector<node>bch;
for(int i=head[u];i;i=e[i].next){
int v=e[i].to,w=e[i].w;
if(!vis[v])bch.push_back((node){v,w});
}
sort(bch.begin(),bch.end());
int m=bch.size();
for(int i=0;i<m;i++){
if(i&&bch[i].w!=bch[i-1].w)merge(rt1,rt2),rt2=0;
top=0;
getdis(bch[i].to,u,c[bch[i].w],1,bch[i].w);
for(int j=1;j<=top;j++){
ans=max(ans,query(rt1,1,0,n,max(0,L-pat[j].first),R-pat[j].first)+pat[j].second);
ans=max(ans,query(rt2,2,0,n,max(0,L-pat[j].first),R-pat[j].first)+pat[j].second-c[bch[i].w]);
}
for(int j=1;j<=top;j++){
ins(rt2,2,0,n,pat[j].first,pat[j].second);
}
}
}
void div(int now){
vis[now]=1;cal(now);
for(int i=head[now];i;i=e[i].next){
int v=e[i].to;
if(vis[v])continue;
alls=son[v],RT=0,MX=INF;
getroot(v,v);
div(RT);
}
}
#define endl '\n'
signed main(){
std::ios::sync_with_stdio(false);
cin.tie(0),cout.tie(0);
cin>>n>>m>>L>>R;
for(int i=1;i<=m;i++)cin>>c[i];
for(int i=1;i<n;i++){
int u,v,w; cin>>u>>v>>w;
add(u,v,w),add(v,u,w);
}
alls=n,MX=INF;
getroot(1,1);
div(RT);
cout<<ans<<endl;
}