首先考虑枚举lca的做法,对于每一个lca枚举其子树中所有节点,时间复杂度$O(n^2)$显然过不了
再思考发现这是一个针对子树的询问操作,考虑dsu on tree来统计答案
开一个新数组vec[x],其中x为权值,记录了所有权值为x的编号
那么只需要每次计算一颗新子树时,先累加答案,再更新vec数组(如果同时进行则可能会出现同一个子树间的节点被统计,然后他们的lca并不是我们枚举的lca)
但是这样时间复杂度反而变成了$O(n^2 logn)$
怎么办呢?考虑拆位。由于异或时每一位之间的运算是互相独立的(不存在进位借位等操作),所以我们对于每一位单独统计,就不用枚举编号啦(或者说枚举vec数组)
我们记数组cnt[x][0/1],表示在编号第bit位为0/1时,权值为x的子树中的节点数量
那么$ans = \sum $cnt[a[x]^a[lca]][(x>>bit)&1^1]
这样的时间复杂度是$O(n \ logn \ logm)$的
我的统计方法是在主函数中枚举bit,每次进行一次dsu on tree
看了一眼某乎,貌似我这样的写法可能会T,更快的做法是扩充cnt,记录cnt[x][bit][0/1],在每次dsu的时候把全部算完...
#include<bits/stdc++.h> using namespace std; inline int read(){ int ans=0,f=1;char chr=getchar(); while(!isdigit(chr)){if(chr=='-')f=-1;chr=getchar();} while(isdigit(chr))ans=(ans<<3)+(ans<<1)+chr-48,chr=getchar(); return ans*f; }const int M = 2e5+5; int n,m,a[M],cnt[(1<<20)+10000][2],tot,head[M],nxt[M],ver[M],sz[M],son[M],T,bit,b[M],ans,id[M],maxn; inline void add(int x,int y){ver[++tot]=y,nxt[tot]=head[x],head[x]=tot;} void dfs(int x,int fa){ sz[x]=1; for(int i=head[x];i;i=nxt[i]) if(ver[i]!=fa) dfs(ver[i],x), sz[x]+=sz[ver[i]], (sz[ver[i]]>sz[son[x]])&&(son[x]=ver[i]); } inline void Clear(){ for(register int i=1;i<=T;++i) --cnt[b[i]][(id[i]>>bit)&1]; T=0; } inline void Insert(int x,int lca){ b[++T]=a[x],id[T]=x; if(x!=lca)ans+=cnt[a[x]^a[lca]][((x>>bit)&1)^1]; } inline void Re_Add(int x,int fa){ ++cnt[a[x]][(x>>bit)&1]; for(register int i=head[x];i;i=nxt[i]) if(ver[i]!=fa)Re_Add(ver[i],x); } void Add_ans(int x,int fa,int lca){ Insert(x,lca); for(register int i=head[x];i;i=nxt[i]) if(ver[i]!=fa)Add_ans(ver[i],x,lca); } void dsu(int x,int fa){ for(int i=head[x];i;i=nxt[i]) if(ver[i]!=son[x]&&ver[i]!=fa)dsu(ver[i],x),Clear(); if(son[x])dsu(son[x],x); ++cnt[b[++T]=a[x]][(x>>bit)&1],id[T]=x; for(int i=head[x];i;i=nxt[i]) if(ver[i]!=fa&&ver[i]!=son[x])Add_ans(ver[i],x,x),Re_Add(ver[i],x); } signed main(){ n=read(); for(int i=1;i<=n;i++)a[i]=read(),maxn=max(maxn,a[i]); for(int i=1,x,y;i<n;i++)add(x=read(),y=read()),add(y,x); dfs(1,0); long long res=0; for(bit=0;(1<<bit)<=maxn;++bit){ ans=0,dsu(1,0),Clear(); res+=ans*(1ll<<bit); }cout<<res; return 0; }