题目
给出一棵树,每个节点上有权值\(a_i\),多次询问一条路径上选择一些点权值异或和最大值。\(n\le 2\times 10^4,q\le 2\times 10^5,0\le a_i\le 2\times 2^{60}\)。
分析
选择一些点的异或和最大值显然用到线性基,这又是一个树上的路径问题,所以可以考虑倍增。预处理出倍增线性基,查询的时候倍增合并线性基,利用线性基的方法查询(从高往低能变大就异或)最大值。
这个总复杂度为\(O((n+q)\log n\log^2 a)\)。(线性基插入为\(\log a\),合并的时候有\(\log a\)个要插入)。
算起来是非常大的呢!但是可以过!
于是就开始寻找更好的做法。
有两种更好的方法,参考上面zwl的博客。
首先由于线性基重复是不影响的,所以可以直接按 \(O(1)\) RMQ的方法来做,这样一共最多只需要合并4次线性基,所以就少掉一个log!
另一种做法是用点分治离线处理询问。对于每个重心求出到每个子树中节点的线性基,再给子树标号,判断当前重心上的询问是否通过重心。如果是的话就合并一次线性基得到答案,否则就把询问丢进所在的子树中。
这样总复杂度为\(O(n\log n\log a+q\log ^2a+q\log n)\),分别是点分构建线性基,总询问复杂度和每层扫描询问复杂度。
很奇妙呢!!
代码
于是我写的是开始的正常(不科学)方法。
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long giant;
const int maxn=2e4+1;
const int maxj=15;
const int maxg=61;
int n,m,f[maxn][maxj],dep[maxn];
giant a[maxn];
struct base {
giant a[maxg];
inline void clear() {memset(a,0,sizeof a);}
base () {clear();}
void insert(giant x) {
for (int j=maxg-1;j>=0;--j) if ((x>>j)&1) {
if ((a[j]>>j)&1) x^=a[j]; else {
a[j]=x;
break;
}
}
}
giant mx() {
giant ret=0;
for (int j=maxg-1;j>=0;--j) if ((ret^a[j])>ret) ret^=a[j];
return ret;
}
} b[maxn][maxj];
base operator + (base a,base b) {
for (int j=maxg-1;j>=0;--j) if (b.a[j]) a.insert(b.a[j]);
return a;
}
vector<int> g[maxn];
inline void add(int x,int y) {g[x].push_back(y);}
void dfs(int x,int fa) {
f[x][0]=fa;
dep[x]=dep[fa]+1;
for (int v:g[x]) if (v!=fa) {
b[v][0].insert(a[x]);
dfs(v,x);
}
}
int lca(int x,int y) {
if (dep[x]<dep[y]) swap(x,y);
for (int j=maxj-1;j>=0;--j) if (dep[f[x][j]]>=dep[y]) x=f[x][j];
if (x==y) return x;
for (int j=maxj-1;j>=0;--j) if (f[x][j]!=f[y][j]) x=f[x][j],y=f[y][j];
return f[x][0];
}
base get(int x,int p) {
base ret;
ret.insert(a[x]);
for (int j=maxj-1;j>=0;--j) if (dep[f[x][j]]>=dep[p]) ret=ret+b[x][j],x=f[x][j];
return ret;
}
int main() {
#ifndef ONLINE_JUDGE
freopen("test.in","r",stdin);
#endif
scanf("%d%d",&n,&m);
for (int i=1;i<=n;++i) scanf("%lld",a+i);
for (int i=1;i<n;++i) {
int x,y;
scanf("%d%d",&x,&y);
add(x,y),add(y,x);
}
dfs(1,1);
for (int j=1;j<maxj;++j) for (int i=1;i<=n;++i) {
f[i][j]=f[f[i][j-1]][j-1];
b[i][j]=b[i][j-1]+b[f[i][j-1]][j-1];
}
while (m--) {
int x,y;
scanf("%d%d",&x,&y);
if (x==y) {
printf("%lld\n",a[x]);
continue;
}
if (dep[x]>dep[y]) swap(x,y);
int l=lca(x,y);
base c=(x==l?get(y,l):get(x,l)+get(y,l));
giant ans=c.mx();
printf("%lld\n",ans);
}
return 0;
}