分析
这道题思维程度实在高 (反正蒟蒻在考场上打不出来orz)
本不该是第二题的难度QWQ
运用LCA+桶+树上差分
对于当前点x,找出对它所有有贡献的点(是起点,并且能够在w[i]时间到达i点)
先用树上倍增/Tarjan求出所有起点s[i]和终点t[i]的LCA,并用dis[i]数组保存s[i]到t[i]的距离。
怎么找对每一个点有贡献的点呢?
设起点u,终点v,u和v的最近公共祖先lca,deep[x]表示点x的深度。此时的dis[i]为u到v的距离
情况一: 从u经过点i到 lca (即从下向上走)
若deep[x]+w[x]=deep[u]
则点u对于i有贡献 因为(从u点出发到达i点,刚好经过w[i]时间)
情况二: 从u经过lca到点i (即先向上再向下走)
若w[x]-deep[x]=dis[i]-deep[v]
则点v对于i有贡献(等同于点u对于i有贡献
(同理从u点出发到达i点,刚好经过w[i]时间)
这个可能说起来有点抽象,但只要画个图就一目了然啦~
最后将两种情况合并起来,统计(运用dfs从下向上统计)所有对当前点x有贡献的点的个数,保存在ans[x]中,最后输出就可以啦。
代码如下
#include<bits/stdc++.h>
using namespace std;
const int N=300000;
int n,m,a[N],head[N],h1[N],h2[N],f[N][20];
int deep[N],s[N],t[N],st[N];
int w[N],ans[N],tot,c1,c2;
int b1[N*2],b2[N*2],dis[N];
struct edge{
int ver,to;
}e[N*2],e1[N*2],e2[N*2];
int read(){
int sum=0,f=1;
char ch=getchar();
while(ch>'9'||ch<'0')
{
if(ch=='-')f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
sum=(sum<<3)+(sum<<1)+ch-'0';
ch=getchar();
}
return sum*f;
}
void add(int x,int y)
{
e[++tot].ver=y;
e[tot].to =head[x];
head[x]=tot;
}
void add1(int x,int y)
{
e1[++c1].ver=y;
e1[c1].to =h1[x];
h1[x]=c1;
}
void add2(int x,int y)
{
e2[++c2].ver=y;
e2[c2].to =h2[x];
h2[x]=c2;
}
void dfs(int x){
for(int i=1;(1<<i)<=deep[x];i++){
f[x][i]=f[f[x][i-1]][i-1];
}
for(int i=head[x];i;i=e[i].to){
int y=e[i].ver;
if(y==f[x][0])continue;
f[y][0]=x;
deep[y]=deep[x]+1;
dfs(y);
}
}
int lca(int x,int y){
if(x==y)return x;
if(deep[y]>deep[x])swap(x,y);
int t=log(deep[x]-deep[y])/log(2);
for(int i=t;i>=0;i--)
if(deep[f[x][i]]>=deep[y])
{
x=f[x][i];
}
if(x==y)return x;
t=log(deep[x])/log(2);
for(int i=t;i>=0;i--)
{
if(f[x][i]!=f[y][i])
{
x=f[x][i];
y=f[y][i];
}
}
return f[x][0];
}
void dfs2(int x){
int t1=b1[w[x]+deep[x]],t2=b2[w[x]-deep[x]+N];
for(int i=head[x];i;i=e[i].to){
int y=e[i].ver;
if(y==f[x][0])continue;
dfs2(y);
}
b1[deep[x]]+=st[x];
for(int i=h1[x];i;i=e1[i].to){
int y=e1[i].ver;
b2[dis[y]-deep[t[y]]+N]++;
}
ans[x]+=b1[w[x]+deep[x]]-t1+b2[w[x]-deep[x]+N]-t2;
for(int i=h2[x];i;i=e2[i].to){
int y=e2[i].ver;
b1[deep[s[y]]]--;
b2[dis[y]-deep[t[y]]+N]--;
}
}
int main(){
// freopen("running.in","r",stdin);
// freopen("running.out","w",stdout);
n=read();
m=read();
int u,v;
for(int i=1;i<n;i++)
{
u=read();
v=read();
add(u,v);
add(v,u);
}
deep[1]=1;
f[1][0]=1;
dfs(1);
for(int i=1;i<=n;i++)
w[i]=read();
for(int i=1;i<=m;i++)
{
s[i]=read();
t[i]=read();
int fx=lca(s[i],t[i]);
st[s[i]]++;
dis[i]=deep[s[i]]+deep[t[i]]-2*deep[fx];
add1(t[i],i);
add2(fx,i);
if(deep[fx]+w[fx]==deep[s[i]]) ans[fx]--;
}
dfs2(1);
for(int i=1;i<=n;i++)
cout<<ans[i]<<' ';
return 0;
}