还是树剖模板啊QwQ (Yet another heavy-light decomposition template.)
// Heavy-light decomposition (HLD) over a tree rooted at vertex 1, backed by a
// segment tree indexed by HLD position. Each segment-tree node keeps the sum
// and the maximum of the vertex values in its interval.
//
// Input (stdin): n, then n-1 undirected edges, then n vertex values, then m
// commands. Command dispatch is by the command word:
//   - 6 characters              -> read "x v", set value of vertex x to v
//   - 4 characters, s[1] == 'M' -> read "x y", print max value on path x..y
//   - anything else             -> read "x y", print sum of values on path x..y
// NOTE(review): presumably these are "CHANGE" / "QMAX" / "QSUM"; the code only
// checks length and s[1], so confirm against the problem statement.
#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;

// Capacity for vertices; segment tree uses 4x, edge list 2x (undirected).
const int maxn = 600100;
// Identity for max queries (same value as the original -1e18).
const long long neg_inf = -1000000000000000000LL;

// Segment-tree node: maximum and sum over its interval of HLD positions.
struct node
{
    long long max_num, sum;
};
node tree[maxn << 2];

int n, m;
// num[v]: value of vertex v; num_rnk[p]: value of the vertex at HLD position p.
long long num[maxn], num_rnk[maxn];
// Per-vertex HLD data: depth, parent, heavy child (-1 = leaf), subtree size,
// HLD position, top vertex of its heavy chain. rnk_cnt is the position counter.
int dep[maxn], fa[maxn], son[maxn], siz[maxn], rnk[maxn], top[maxn], rnk_cnt;
// tot: scratch accumulator filled by get_sum.
// ans: result of tree_max (maximum) or tree_sum (accumulated sum).
long long tot, ans;

// Adjacency list as singly linked edge lists: first[v] is the newest edge of v.
struct nn
{
    int ed, nxt;
};
nn edge[maxn << 1];
int cnt, first[maxn];

// Appends a directed edge s -> e.
inline void add_edge(int s, int e)
{
    ++cnt;
    edge[cnt].ed = e;
    edge[cnt].nxt = first[s];
    first[s] = cnt;
}

// First DFS: records parent, depth and subtree size, and picks as heavy child
// the child with the largest subtree. Expects son[] pre-filled with -1.
inline void dfs_1(int now, int pre)
{
    fa[now] = pre;
    dep[now] = dep[pre] + 1;
    siz[now] = 1;
    for (int i = first[now]; i; i = edge[i].nxt)
    {
        int e = edge[i].ed;
        if (e == fa[now]) continue;
        dfs_1(e, now);
        siz[now] += siz[e];
        if (son[now] == -1 || siz[e] > siz[son[now]]) son[now] = e;
    }
}

// Second DFS: assigns HLD positions, visiting the heavy child first so that
// each heavy chain occupies consecutive positions, and records chain tops.
inline void dfs_2(int now, int heavy_fa)
{
    rnk[now] = ++rnk_cnt;
    num_rnk[rnk_cnt] = num[now];
    top[now] = heavy_fa;
    if (son[now] == -1) return;
    dfs_2(son[now], heavy_fa);
    for (int i = first[now]; i; i = edge[i].nxt)
    {
        int e = edge[i].ed;
        if (e != fa[now] && e != son[now]) dfs_2(e, e);  // light child starts a new chain
    }
}

// Recomputes node k's aggregates from its two children.
inline void push_up(int k)
{
    int lc = k << 1;
    tree[k].sum = tree[lc].sum + tree[lc | 1].sum;
    tree[k].max_num = max(tree[lc].max_num, tree[lc | 1].max_num);
}

// Builds node k over positions [l, r] from num_rnk.
inline void build(int k, int l, int r)
{
    if (l == r)
    {
        tree[k].sum = num_rnk[l];
        tree[k].max_num = num_rnk[l];
        return;
    }
    int mid = (l + r) >> 1, lc = k << 1;
    build(lc, l, mid);
    build(lc | 1, mid + 1, r);
    push_up(k);
}

// Sets every position in [x, y] to v. The assignment is applied at leaves, so
// it is correct for any range (writing v into a covering internal node would
// give a wrong sum). It visits every leaf in [x, y], so it is only efficient
// for single-position updates, which is how main uses it.
inline void modify(int k, int l, int r, int x, int y, long long v)
{
    if (r < x || l > y) return;
    if (l == r)
    {
        tree[k].sum = v;
        tree[k].max_num = v;
        return;
    }
    int mid = (l + r) >> 1, lc = k << 1;
    modify(lc, l, mid, x, y, v);
    modify(lc | 1, mid + 1, r, x, y, v);
    push_up(k);
}

// Returns the maximum over positions [x, y] intersected with node k's [l, r],
// or neg_inf if they do not intersect.
inline long long query(int k, int l, int r, int x, int y)
{
    if (r < x || l > y) return neg_inf;
    if (x <= l && r <= y) return tree[k].max_num;
    int mid = (l + r) >> 1, lc = k << 1;
    return max(query(lc, l, mid, x, y), query(lc | 1, mid + 1, r, x, y));
}

// Stores in ans the maximum vertex value on the tree path x..y.
inline void tree_max(int x, int y)
{
    ans = neg_inf;
    while (top[x] != top[y])
    {
        // Lift the endpoint whose chain top is deeper.
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        ans = max(ans, query(1, 1, n, rnk[top[x]], rnk[x]));
        x = fa[top[x]];
    }
    ans = max(ans, query(1, 1, n, min(rnk[x], rnk[y]), max(rnk[x], rnk[y])));
}

// Adds to tot the sum over positions [x, y] intersected with node k's [l, r].
inline void get_sum(int k, int l, int r, int x, int y)
{
    if (r < x || l > y) return;
    if (x <= l && r <= y)
    {
        tot += tree[k].sum;
        return;
    }
    int mid = (l + r) >> 1, lc = k << 1;
    get_sum(lc, l, mid, x, y);
    get_sum(lc | 1, mid + 1, r, x, y);
}

// Adds the sum of vertex values on the tree path x..y to ans.
// ans is not reset here; the caller sets it to 0 first.
inline void tree_sum(int x, int y)
{
    while (top[x] != top[y])
    {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        tot = 0;
        get_sum(1, 1, n, rnk[top[x]], rnk[x]);
        ans += tot;
        x = fa[top[x]];
    }
    tot = 0;
    get_sum(1, 1, n, min(rnk[x], rnk[y]), max(rnk[x], rnk[y]));
    ans += tot;
}

int main()
{
    //freopen("count1.in","r",stdin);
    //freopen("gun.out","w",stdout);
    // An empty tree would make build(1, 1, 0) recurse forever.
    if (scanf("%d", &n) != 1 || n <= 0) return 0;
    for (int i = 1; i <= n - 1; ++i)
    {
        int s, e;
        if (scanf("%d%d", &s, &e) != 2) return 0;
        add_edge(s, e);
        add_edge(e, s);
    }
    for (int i = 1; i <= n; ++i)
        if (scanf("%lld", &num[i]) != 1) return 0;
    memset(son, -1, sizeof(son));
    dfs_1(1, 0);
    dfs_2(1, 1);
    build(1, 1, n);
    if (scanf("%d", &m) != 1) return 0;
    for (int i = 1; i <= m; ++i)
    {
        char s[10];
        // Field width keeps the read inside the 10-byte buffer.
        if (scanf("%9s", s) != 1) break;
        const size_t len = strlen(s);
        if (len == 6)
        {
            int x;
            long long v;
            scanf("%d%lld", &x, &v);
            modify(1, 1, n, rnk[x], rnk[x], v);
        }
        else if (len == 4 && s[1] == 'M')
        {
            int x, y;
            scanf("%d%d", &x, &y);
            ans = neg_inf;
            tree_max(x, y);
            printf("%lld\n", ans);
        }
        else
        {
            int x, y;
            scanf("%d%d", &x, &y);
            ans = 0;
            tree_sum(x, y);
            printf("%lld\n", ans);
        }
    }
    return 0;
}