Description
给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v),你需要回答u xor lastans和v这两个节点间有多少种不同的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。
Input
第一行两个整数N,M。
第二行有N个整数,其中第i个整数表示点i的权值。
后面N-1行每行两个整数(x,y),表示点x到点y有一条边。
最后M行每行两个整数(u,v),表示一组询问。
数据范围是N<=40000 M<=100000 点权在int范围内
Output
M行,表示每个询问的答案。
1.每次在树上找到一棵树高为sqrt(n)的子树分为一块并删去,最后剩下树高不足sqrt(n)的部分(如果有)另成一块,这样树就被分为至多sqrt(n)块,且每块高度不超过sqrt(n)
2.对每个块的树根,预处理出这个点到树上所有点的路径的答案
3.对每个点,预处理出这个点到1号点的路径上每种点权出现的最大深度,用可持久化块状数组保存(数组分为sqrt(n)块,每个节点保存指向这些块的指针,修改时暴力复制所修改的块并修改)
4.对一个询问(u,v),若u,v在同一块则暴力,否则令u所在块的根的深度大于v所在块的根的深度,设x为u所在块的根,则x到v的答案已预处理好,根据第3步预处理的内容可以查询u到x的路径上(不包括x)的点权是否在x到v的路径上出现过从而得到答案
以上每一步均是O(n3/2)时间复杂度,常数非常大,可能需要一些常数优化
upd: (1)中分块可在O(n)时间完成,另外bitset优化的 树分块+ST表(同bzoj4763)可以做到n2/32,且实际运行效果比上述算法更好
#include<bits/stdc++.h>
const int N=;
int n,m,B;
int read(){
int x=,c=getchar();
while(c<)c=getchar();
while(c>)x=x*+c-,c=getchar();
return x;
}
int v[N],vs[N],e[N*][],e0[N],ep=,la=;
int id[N],rt[N],idp=,fa[N],dep[N],t[N],ANS=;
int sz[N],pf[N],top[N];
int ans[][N],h[N];
int mem[*N],*ptr=mem;
struct Array{
int*arr[];
const int&operator[](int x){
return arr[x>>][x&];
}
void copy(Array&src,int x,int y){
memcpy(&arr,&src.arr,*);
ptr+=;
memcpy(ptr,arr[x>>],*);
ptr[x&]=y;
arr[x>>]=ptr;
}
}as[N];
int lca(int x,int y){
int a=top[x],b=top[y];
while(a!=b){
if(dep[a]<dep[b])std::swap(a,b),std::swap(x,y);
x=fa[a];a=top[x];
}
return dep[x]<dep[y]?x:y;
}
int vio(int x,int y){
int a=lca(x,y),r=;
for(int w=x;w!=a;w=fa[w])if(!t[v[w]]++)++r;
for(int w=y;w!=a;w=fa[w])if(!t[v[w]]++)++r;
if(!t[v[a]])++r;
for(int w=x;w!=a;w=fa[w])t[v[w]]=;
for(int w=y;w!=a;w=fa[w])t[v[w]]=;
return r;
}
int query(int x,int y){
if(id[x]==id[y])return vio(x,y);
if(dep[rt[id[x]]]<dep[rt[id[y]]])std::swap(x,y);
int d=dep[lca(x,y)],b=rt[id[x]];
int r=ans[id[b]][y];
for(int w=x;w!=b;w=fa[w]){
int c=v[w];
if(!t[c]&&as[b][c]<d&&as[y][c]<d)++r,t[c]=;
}
for(int w=x;w!=b;w=fa[w])t[v[w]]=;
return r;
}
void f3(int w,int pa,int ID){
if(!t[v[w]]++)++ANS;
ans[ID][w]=ANS;
for(int i=e0[w];i;i=e[i][]){
int u=e[i][];
if(u!=pa)f3(u,w,ID);
}
if(!--t[v[w]])--ANS;
}
void f1(int w,int pa){
sz[w]=;
fa[w]=pa;
dep[w]=dep[pa]+;
for(int i=e0[w];i;i=e[i][]){
int u=e[i][];
if(u!=pa){
f1(u,w);
sz[w]+=sz[u];
if(sz[u]>sz[pf[w]])pf[w]=u;
}
}
}
void f2(int w){
for(int i=e0[w];i;i=e[i][]){
int u=e[i][];
if(u!=fa[w]&&!id[u]){
id[u]=id[w];
f2(u);
}
}
}
void f4(int w){
as[w].copy(as[fa[w]],v[w],dep[w]);
for(int i=e0[w];i;i=e[i][]){
int u=e[i][];
if(u!=fa[w])f4(u);
}
}
void f5(int w,int tp){
top[w]=tp;
if(pf[w])f5(pf[w],tp);
for(int i=e0[w];i;i=e[i][]){
int u=e[i][];
if(u!=fa[w]&&u!=pf[w])f5(u,u);
}
}
void f6(int w){
h[w]=;
for(int i=e0[w];i;i=e[i][]){
int u=e[i][];
if(u!=fa[w]&&!id[u]){
f6(u);
if(h[u]>=h[w])h[w]=h[u]+;
}
}
}int main(){
n=read();m=read();
B=sqrt(n+)+;
for(int i=;i<=n;i++)vs[i]=v[i]=read();
std::sort(vs+,vs+n+);
for(int i=;i<=n;i++)v[i]=std::lower_bound(vs+,vs+n+,v[i])-vs;
for(int i=;i<n;i++){
int a=read(),b=read();
e[ep][]=b;e[ep][]=e0[a];e0[a]=ep++;
e[ep][]=a;e[ep][]=e0[b];e0[b]=ep++;
}
for(int i=;i<;i++)as[].arr[i]=mem;
f1(,);
do{
f6();
int r=;
for(int i=;i<=n;i++)if(!id[i]&&h[i]==B){
r=i;
break;
}
rt[id[r]=++idp]=r;
f2(r);
f3(r,,idp);
}while(!id[]);
f4();f5(,);
while(m--){
int x=read()^la,y=read();
printf("%d\n",la=query(x,y));
}
return ;
}