一道非常精彩,同时也很经典的题目。和这场的 C 一样经典
首先看到这个数据范围先猜正解复杂度:\(n\) 级别大于 \(q\),所以大概是 \(n\sqrt{n}\log n+q\sqrt{n}\),事实的确如此。
按照最常规的思路,看到树上路径统计的问题无非两种可能,树链剖分、倍增这一类在线数据结构或者离线下来一遍 DFS,把询问挂在 \(u\) 或 \(v\) 某一个节点上然后求解,但不管你使用哪一种,仔细想想都会发现需要涉及这样一个问题:对于两个序列 \(a,b\),我们要支持在已知 \(a_i\oplus b_i\) 的最小值的情况下,快速对给定的 \(x\) 求出 \((a_i+x)\oplus b_i\) 的最小值。而众所周知这是不好维护的,因此考虑不常规的数据结构,也就是分块。
我们考虑对于 \(u\to v\) 这条链上的元素,从下到上每 \(256\) 个元素一块,最后多出来的那一部分元素 \((\le 255)\) 暴力统计贡献即可,复杂度显然不会出问题。因此接下来我们的任务就是考虑怎样计算整块的贡献,我们这里假设编号从 \(0\) 开始,第 \(0\) 块表示 \(u\) 到 \(u\) 的 \(255\) 级祖先路径上的点,第 \(1\) 块表示 \(u\) 的 \(256\) 级祖先到 \(511\) 级祖先路径上的点,以此类推。那么假设第 \(i\) 块中节点的权值分别是 \(b_0,b_1,\cdots,b_{255}\),那么这一段对答案的贡献就是 \(\max\limits_{j=0}^{255}b_j\oplus(256i+j)\),我们假设这一块中第 \(0\) 个节点(也就是 \(u\) 的 \(256i\) 级祖先)为 \(x\),那么我们考虑设 \(f_{x,i}=\max\limits_{j=0}^{255}a_{fa_{x,j}}\oplus(256i+j)\),其中 \(fa_{x,j}\) 表示 \(x\) 的 \(j\) 级祖先,那么显然这一块对答案的贡献就是 \(f_{x,i}\),因此我们只要能预处理出 \(f_{x,i}\) 就可以计算出答案。
接下来考虑怎样预处理出 \(f_{x,i}\),如果我们设 \(b_i=\lfloor\dfrac{a_i}{256}\rfloor,c_i=a_i\ \text{and}\ 255\),那么式子可以写作 \(f_{x,i}=\max\limits_{j=0}^{255}(b_{fa_{x,j}}\oplus i)·256+(c_{fa_{x,j}}\oplus j)\),不难发现两部分是独立的,因此考虑这样计算 \(f_{x,i}\):对于每个 \(x\),计算所有 \(f_{x,i}\) 时,先将它的 \(0\sim 255\) 级祖先的 \(b\) 插入 trie 中,这样对于所有 \(i\) 在 trie 树上查询一下即可知道 \(b_{fa_{x,j}}\oplus i\) 的最大值,而后面那部分与 \(i\) 无关,我们假设 \(b_{fa_{x,j}}\oplus i\) 取到最大值时,\(b_{fa_{x,j}}=v\),那么后面那部分的最大值就是,所有 \(b_{fa_{x,j}}=v\) 的 \(j\) 中 \((c_{fa_{x,j}}\oplus j)\) 的最大值,因此我们考虑再开一个桶 \(mxv_v\) 维护这个东西,而这显然是可以在枚举 \(i\) 之前完成的,因此枚举 \(i\) 求出 \(b_{fa_{x,j}}\oplus i\) 之后就可以 \(\mathcal O(1)\) 地求出后面那部分的最大值。这样我们就在 \(n\sqrt{n}\log n\) 的时间内求出 \(f_{x,i}\)
总复杂度确实是 \(n\sqrt{n}\log n+q\sqrt{n}\)
所以这大概也算树分块的一个小应用了吧,之前只刷了树分块的模板题却不知道它的应用。
const int MAXN=65536;
const int MAXP=2048;
int n,qu,a[MAXN+5],hd[MAXN+5],to[MAXN*2+5],nxt[MAXN*2+5],ec=0;
void adde(int u,int v){to[++ec]=v;nxt[ec]=hd[u];hd[u]=ec;}
int dep[MAXN+5],fa[MAXN+5][10];
void dfs0(int x,int f){
fa[x][0]=f;
for(int e=hd[x];e;e=nxt[e]){
int y=to[e];if(y==f) continue;
dep[y]=dep[x]+1;dfs0(y,x);
}
}
int mx[MAXN+5][258],mxv[258];
int ch[MAXP+5][2],ncnt=0,rt=0;
void insert(int &k,int d,int v){
if(!k) k=++ncnt;if(!~d) return;
insert(ch[k][v>>d&1],d-1,v);
}
int query(int k,int d,int v){
if(!~d) return 0;
if(ch[k][~v>>d&1]) return query(ch[k][~v>>d&1],d-1,v)|(1<<d);
else return query(ch[k][v>>d&1],d-1,v);
}
void clear(){memset(ch,0,sizeof(ch));ncnt=rt=0;}
int main(){
scanf("%d%d",&n,&qu);
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
for(int i=1,u,v;i<n;i++) scanf("%d%d",&u,&v),adde(u,v),adde(v,u);dfs0(1,0);
for(int i=1;i<=8;i++) for(int j=1;j<=n;j++) fa[j][i]=fa[fa[j][i-1]][i-1];
for(int i=1;i<=n;i++){
clear();memset(mxv,0,sizeof(mxv));
for(int j=i,stp=0;j&&stp<256;j=fa[j][0],stp++){
chkmax(mxv[a[j]>>8],stp^(a[j]&255));
insert(rt,7,a[j]>>8);
}
for(int j=0;j<256;j++){
int p=query(rt,7,j);
mx[i][j]=(p<<8)|mxv[p^j];
}
} while(qu--){
int u,v,p;scanf("%d%d",&u,&v);p=v;int cur=0,res=0;
while(dep[p]-dep[u]>255) chkmax(res,mx[p][cur]),p=fa[p][8],++cur;
while(p^u) chkmax(res,a[p]^(dep[v]-dep[p])),p=fa[p][0];
chkmax(res,a[u]^(dep[v]-dep[u]));printf("%d\n",res);
}
return 0;
}