来一发清新的80行 树剖 $LCA$ + 树上差分 题解。
-----from Judge
本题题意大概是给出一棵 n 个节点的树以及 m 条有向路径,
并且每个点 i 都有一个权值 $w[i]$,如果某条路径包含了 i 号节点,并且 i 号节点是该路径上的第 $w[i]$ 个节点的话就会对答案产生贡献。
考虑暴力做法
我们可以十分理所当然的想到一个暴力: u 和 v 向上跑,沿路判断条件并累加,直到跑到 $LCA$ 。
但是这个暴力的时间复杂度太高了: $O (n + m )$ 。于是我们考虑别的方法。
首先我们看 $n$ 和 $m$ 都是 $3e5$ 的数据范围,那么我们对于 $m$ 条路径的处理当然不能太大, 那么我们就考虑:用 **起点** 和 **终点** 来累加答案。
考虑路径的拆分
那么如何累加? 我们先考虑:每个节点要对答案产生贡献则必然出现在一条路径$(u->v)$ 上,又因为 $u$ 先会跑到 $LCA(u,v)$ ,然后再跑到 $v$。
即:$u$ -> $LCA(u,v)$ -> $v$
那么我们可以把路径拆成两条,分别处理 $u$ 到 $lca$ 路径上点的贡献,以及 $lca$ 到 $v$ 路径上点的贡献
考虑第一条路径上的点
我们可以推导出,在 $dep[i] + w[i] = dep[u]$ 的情况先, i 号节点会对答案产生贡献。
即:在 i 的子树内,若有 $dep[u] = dep[i] + w[i]$ (至于 i 不在 $u -> LCA$ 的路径上的情况我们可以用差分思想解决)
那么我们只需要判断 i 的子树内 有无 $dep[u]$ 满足该式即可,对此我们可以开一个计数数组 $cnt$ 。
然后对于一个点我们可以记录下当前点出发的节点数,每次深搜到该点就对计数数组累加,
同时我们扫描当前节点作为哪些子节点的 $LCA$ 出现过,将这些子节点的 $dep$ 减掉就可以达到差分的效果了。
考虑第二条路径上的点
类似的,我们可以推导出关于 u,v 和 i 的等式:
$dep[i] - w[i] = dep[v] - dis(u,v)$
那么这里 dep 减掉 w[i] 后可能是负数,对此我们将左右式同时加上 n 即可(题目条件:$w[i]<=n$)。
那么这样我们就要判断 i 的子树内是否有 $dep[v] - dis(u,v) = dep[i] - w[i]$ 即可。
然后同上操作。
遍历处理
那么我们只需要对树进行两次遍历就可以处理出以上信息了。
考虑重复贡献的删除
我们可以观察到,如果某个节点 $i$ 就是 $u$ 、$v$ 的 $LCA$
那么该节点的贡献是会被累加两次的。对此我们如何消除多余贡献?这里有两种方法:
暴力删除
在程序最后暴力枚举 m 条路径 的 LCA 并判断其是否在该路径上,在的话就减贡献。
修改遍历树的方式
其实我们完全可以在第二次遍历树的时候先对以 $i$ 为 $LCA$ 的路径上的终点信息先删除,然后再累加答案,相当于我们在做第二条路径的时候忽略掉 $LCA$ 这个节点。那么对此的操作也很简单,程序执行顺序换一下就好了。
//by Judge
#include<iostream>
#include<vector>
#include<cstdio>
#define ll long long
using namespace std;
const int M=3e5+;
const int inf=1e9+;
#ifndef Judge
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
#endif
char buf[<<],*p1=buf,*p2=buf;
inline int read(){
int x=,f=; char c=getchar();
for(;!isdigit(c);c=getchar()) if(c=='-') f=-;
for(;isdigit(c);c=getchar()) x=x*+c-''; return x*f;
} char sr[<<],z[];int C=-,Z;
inline void Ot(){fwrite(sr,,C+,stdout),C=-;}
inline void print(int x){
if(C><<)Ot();if(x<)sr[++C]=,x=-x;
while(z[++Z]=x%+,x/=);
while(sr[++C]=z[Z],--Z);sr[++C]=' ';
} int n,m,pat,mx;
int w[M],s[M],cnt1[M],cnt2[M<<],ans[M],siz[M],dep[M],son[M],f[M],top[M],head[M];
vector<int> q1[M],q2[M],q3[M];
struct operation{ int u,v,lca,dis; }a[M];
struct Edge{ int to,next; Edge(int to,int next):to(to),next(next){} Edge(){} }e[M<<];
inline void add(int u,int v){
e[++pat]=Edge(v,head[u]),head[u]=pat;
e[++pat]=Edge(u,head[v]),head[v]=pat;
}
#define v e[i].to
void dfs(int u,int fa){
siz[u]=,dep[u]=dep[f[u]=fa]+;
for(int i=head[u];i;i=e[i].next) if(v^fa){
dfs(v,u),siz[u]+=siz[v];
if(siz[v]>siz[son[u]]) son[u]=v;
}
} void dfs(int u){
if(!top[u]) top[u]=u; if(!son[u]) return ;
top[son[u]]=top[u], dfs(son[u]);
for(int i=head[u];i;i=e[i].next)
if(v^f[u] && v^son[u]) dfs(v);
} void dfs1(int u){
int now=w[u]+dep[u],cun; if(now<=mx) cun=cnt1[now];
for(int i=head[u];i;i=e[i].next) if(v^f[u]) dfs1(v);
cnt1[dep[u]]+=s[u]; if(now<=mx) ans[u]=cnt1[now]-cun;
for(int i=;i<q1[u].size();++i) --cnt1[dep[q1[u][i]]];
} void dfs2(int u){
int now=dep[u]-w[u]+n,cum=cnt2[now];
for(int i=head[u];i;i=e[i].next) if(v^f[u]) dfs2(v);
for(int i=;i<q2[u].size();++i) ++cnt2[q2[u][i]+n];
for(int i=;i<q3[u].size();++i) --cnt2[q3[u][i]+n];
ans[u]+=cnt2[now]-cum;
}
#undef v
inline int LCA(int u,int v){
while(top[u]^top[v]){
dep[top[u]]>dep[top[v]]?u=f[top[u]]:v=f[top[v]];
} return dep[u]<dep[v]?u:v;
}
int main(){
n=read(),m=read();
for(int i=,u,v;i<n;++i)
u=read(),v=read(),add(u,v);
for(int i=;i<=n;++i) w[i]=read();
for(int i=;i<=m;++i) a[i].u=read(),a[i].v=read();
dep[]=,dfs(,),dfs(); for(int i=;i<=n;++i) mx=max(mx,dep[i]);
for(int i=;i<=m;++i){
a[i].lca=LCA(a[i].u,a[i].v),++s[a[i].u];
a[i].dis=dep[a[i].u]+dep[a[i].v]-dep[a[i].lca]*;
q1[a[i].lca].push_back(a[i].u);
} dfs1();
for(int i=;i<=m;++i){
q2[a[i].v].push_back(dep[a[i].v]-a[i].dis);
q3[a[i].lca].push_back(dep[a[i].v]-a[i].dis);
} dfs2();
for(int i=;i<=n;++i) print(ans[i]); return Ot(),putchar('\n'),;
}