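/*
 * Solution notes.
 *
 * (Author's note: this problem took me a whole afternoon because of a few
 * silly bugs.)
 *
 * - For a query whose step is larger than a threshold (about sqrt(n) in the
 *   analysis; the code uses the constant K = 500), simply walk the path jump
 *   by jump: there are only about n / step jumps.
 * - Precompute d[u][step] for every step <= K: the sum of the scores of u and
 *   of every vertex reached by repeatedly jumping `step` levels up from u,
 *   until no further jump is possible.
 * - A query with step <= K is then answered by a few table lookups combined
 *   with the LCA of its two endpoints.
 * - The remaining details are left to the reader.
 * - Complexity: O(n * sqrt(n) * log n) overall.
 */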
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <cstdlib>
#include <algorithm>
// Capacity limits.
// N: size of every per-vertex array (vertices are numbered 1..n; 0 is a sentinel).
// M: size of the arc array e[]; each undirected tree edge is stored as two arcs.
// (Previously `using namespace std;` was accidentally part of the M macro,
// which broke every use of M and never brought std names into scope.)
const int N=50004;
const int M=100005;
using namespace std;
// Read the next non-negative decimal integer from stdin.
// Any non-digit characters before it are skipped (including a '-' sign, so
// negative numbers are not supported). If end of input is reached before a
// digit is seen, returns 0 instead of spinning forever on EOF.
inline int read(){
    int ret=0;
    // int, not char: getchar() returns EOF (a negative int), which a plain
    // char cannot represent reliably where char is unsigned.
    int ch=getchar();
    while (ch!=EOF&&(ch<'0'||ch>'9')) ch=getchar();
    while ('0'<=ch&&ch<='9'){
        ret=ret*10+(ch-'0');
        ch=getchar();
    }
    return ret;
}

// One directed arc of the forward-star adjacency list stored in e[].
struct edge{
    int adj;   // head vertex of the arc
    int next;  // index of the previously added arc with the same tail (0 ends the list)
    edge(){}
    edge(int _adj,int _next):adj(_adj),next(_next){}
} e[M];
int n;     // number of vertices
int g[N];  // g[x]: index in e[] of the last arc added with tail x (0 = no arcs)
int m;     // index of the most recently written arc in e[]

// Add the undirected edge {u, v}: first push the arc u->v onto u's list,
// then the arc v->u onto v's list.
void AddEdge(int u,int v){
    const int ends[2]={u,v};
    for (int side=0;side<2;++side){
        const int from=ends[side];
        const int to=ends[1-side];
        e[++m]=edge(to,g[from]);
        g[from]=m;
    }
}

const int K=500;   // step threshold: queries with step <= K use the table d
int a[N];          // a[x]: score of vertex x
int f[N][23];      // f[x][k]: 2^k-th ancestor of x (0 past the root); levels 0..20 are used
int d[N][K+5];     // d[x][s]: sum of a[] over x, fa^s(x), fa^2s(x), ... (1 <= s <= K)
int fa[N];         // fa[x]: parent of x (0 for the root)
int deep[N];       // deep[x]: depth of x (root = 1, sentinel vertex 0 = 0)
void dfs(int u){
deep[u]=deep[fa[u]]+1;
int v=fa[u];
for (int i=1;i<=K;++i,v=fa[v])
d[u][i]=d[v][i]+a[u];
for (int i=g[u];i;i=e[i].next){
int v=e[i].adj;
if (v==fa[u]) continue;
fa[v]=u;
dfs(v);
}
} void precompute(){
fa[1]=fa[0]=0;deep[0]=0;
memset(d[0],0,sizeof(d[0]));
dfs(1);
for (int i=1;i<=n;++i) f[i][0]=fa[i];
memset(f[0],0,sizeof(f[0]));
for (int k=1;k<=20;++k)
for (int i=1;i<=n;++i)
f[i][k]=f[f[i][k-1]][k-1];
} int jump(int u,int step){
for (int k=0;k<=20;++k)if ((step&(1<<k))>0)u=f[u][k];
return u;
}
int qlca(int u,int v){
if (deep[u]<deep[v]) swap(u,v);
u=jump(u,deep[u]-deep[v]);
for (int k=20;k>=0;--k)if (f[u][k]!=f[v][k])u=f[u][k],v=f[v][k];
return u==v?u:fa[u];
} int s[N],t[N];
int main(){
n=read();
for (int i=1;i<=n;++i) a[i]=read();
memset(g,0,sizeof(g));m=1;
for (int i=1;i<n;++i) AddEdge(read(),read());
precompute();
for (int i=1;i<=n;++i) t[i-1]=s[i]=read();
for (int i=1;i<n;++i){
int step=read(),lca=qlca(s[i],t[i]),u=s[i],v=t[i],res=0;
if ((deep[u]+deep[v]-2*deep[lca])%step>0){
res+=a[v];
v=jump(v,(deep[u]+deep[v]-2*deep[lca])%step);
}
if (deep[lca]%step==deep[u]%step) res+=a[lca];
if (step<=K){
int top1=jump(lca,(deep[lca]%step-deep[u]%step+step)%step),top2=jump(lca,(deep[lca]%step-deep[v]%step+step)%step);
res+=d[u][step]-d[top1][step]+d[v][step]-d[top2][step];
}
else for (int j=0;j<2;++j)
for (swap(u,v);deep[u]>deep[lca];u=jump(u,step))
res+=a[u];
printf("%d\n",res);
}
return 0;
}