第一种方法,dfs序上建可持久化线段树,然后询问的时候把两点之间的所有树链扒出来做差。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+,inf=0x3f3f3f3f;
int hd[N],ne,n,n2,m,fa[N],son[N],siz[N],dep[N],top[N],dfn[N],rnk[N],tot,rt[N],ls[N*],rs[N*],val[N*],tot2,a[N],b[N],ql[],qr[],nl,nr;
struct E {int v,nxt;} e[N<<];
void addedge(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++;}
void dfs1(int u,int f,int d) {
fa[u]=f,son[u]=,siz[u]=,dep[u]=d;
for(int i=hd[u]; ~i; i=e[i].nxt) {
int v=e[i].v;
if(v==fa[u])continue;
dfs1(v,u,d+),siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}
void dfs2(int u,int tp) {
top[u]=tp,dfn[u]=++tot,rnk[tot]=u;
if(son[u])dfs2(son[u],tp);
for(int i=hd[u]; ~i; i=e[i].nxt) {
int v=e[i].v;
if(v==fa[u]||v==son[u])continue;
dfs2(v,v);
}
}
#define mid ((l+r)>>1)
void upd(int& u,int v,int x,int l=,int r=n2) {
if(!u)u=++tot2;
val[u]=val[v]+;
if(l==r)return;
if(x<=mid)upd(ls[u],ls[v],x,l,mid),rs[u]=rs[v];
else upd(rs[u],rs[v],x,mid+,r),ls[u]=ls[v];
}
int ask(int u,int v,int k) {
for(nl=nr=; top[u]!=top[v]; u=fa[top[u]]) {
if(dep[top[u]]<dep[top[v]])swap(u,v);
ql[nl++]=rt[dfn[top[u]]-],qr[nr++]=rt[dfn[u]];
}
if(dep[u]<dep[v])swap(u,v);
ql[nl++]=rt[dfn[v]-],qr[nr++]=rt[dfn[u]];
int l=,r=n2;
while(l<r) {
int cnt=;
for(int i=; i<nr; ++i)cnt+=val[ls[qr[i]]];
for(int i=; i<nl; ++i)cnt-=val[ls[ql[i]]];
if(k<=cnt) {
for(int i=; i<nr; ++i)qr[i]=ls[qr[i]];
for(int i=; i<nl; ++i)ql[i]=ls[ql[i]];
r=mid;
} else {
k-=cnt;
for(int i=; i<nr; ++i)qr[i]=rs[qr[i]];
for(int i=; i<nl; ++i)ql[i]=rs[ql[i]];
l=mid+;
}
}
return l;
}
int main() {
memset(hd,-,sizeof hd),ne=;
scanf("%d%d",&n,&m);
for(int i=; i<=n; ++i)scanf("%d",&a[i]);
for(int i=; i<=n; ++i)b[i-]=a[i];
sort(b,b+n),n2=unique(b,b+n)-b;
for(int i=; i<=n; ++i)a[i]=(lower_bound(b,b+n2,a[i])-b)+;
for(int i=; i<n; ++i) {
int u,v;
scanf("%d%d",&u,&v);
addedge(u,v),addedge(v,u);
}
tot=,dfs1(,,),dfs2(,);
memset(rt,,sizeof rt),tot2=;
for(int i=; i<=n; ++i)upd(rt[i],rt[i-],a[rnk[i]],,n2);
for(int last=; m--;) {
int u,v,k;
scanf("%d%d%d",&u,&v,&k),u^=last;
int ans=b[ask(u,v,k)-];
printf("%d\n",ans),last=ans;
}
return ;
}
仔细一想这样似乎麻烦了点。因为没有修改操作,我们可以直接用子结点继承父节点的方式来建线段树,然后查询的时候,用u,v的线段树减去lca的线段树再减去lca父节点的线段树即可。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+,inf=0x3f3f3f3f;
int hd[N],ne,n,n2,m,fa[N],son[N],siz[N],dep[N],top[N],dfn[N],rnk[N],tot,rt[N],ls[N*],rs[N*],val[N*],tot2,a[N],b[N],ql[],qr[],nl,nr;
struct E {int v,nxt;} e[N<<];
void addedge(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++;}
void dfs1(int u,int f,int d) {
fa[u]=f,son[u]=,siz[u]=,dep[u]=d;
for(int i=hd[u]; ~i; i=e[i].nxt) {
int v=e[i].v;
if(v==fa[u])continue;
dfs1(v,u,d+),siz[u]+=siz[v];
if(siz[v]>siz[son[u]])son[u]=v;
}
}
void dfs2(int u,int tp) {
top[u]=tp,dfn[u]=++tot,rnk[tot]=u;
if(son[u])dfs2(son[u],tp);
for(int i=hd[u]; ~i; i=e[i].nxt) {
int v=e[i].v;
if(v==fa[u]||v==son[u])continue;
dfs2(v,v);
}
}
#define mid ((l+r)>>1)
void upd(int& u,int v,int x,int l=,int r=n2) {
if(!u)u=++tot2;
val[u]=val[v]+;
if(l==r)return;
if(x<=mid)upd(ls[u],ls[v],x,l,mid),rs[u]=rs[v];
else upd(rs[u],rs[v],x,mid+,r),ls[u]=ls[v];
}
void dfs3(int u) {
upd(rt[u],rt[fa[u]],a[u]);
for(int i=hd[u]; ~i; i=e[i].nxt) {
int v=e[i].v;
if(v==fa[u])continue;
dfs3(v);
}
}
int lca(int u,int v) {
for(; top[u]!=top[v]; u=fa[top[u]])if(dep[top[u]]<dep[top[v]])swap(u,v);
return dep[u]<dep[v]?u:v;
}
int ask(int u,int v,int w1,int w2,int k,int l=,int r=n2) {
if(l==r)return l;
int cnt=val[ls[u]]+val[ls[v]]-val[ls[w1]]-val[ls[w2]];
return k<=cnt?ask(ls[u],ls[v],ls[w1],ls[w2],k,l,mid):ask(rs[u],rs[v],rs[w1],rs[w2],k-cnt,mid+,r);
}
int main() {
memset(hd,-,sizeof hd),ne=;
scanf("%d%d",&n,&m);
for(int i=; i<=n; ++i)scanf("%d",&a[i]);
for(int i=; i<=n; ++i)b[i-]=a[i];
sort(b,b+n),n2=unique(b,b+n)-b;
for(int i=; i<=n; ++i)a[i]=(lower_bound(b,b+n2,a[i])-b)+;
for(int i=; i<n; ++i) {
int u,v;
scanf("%d%d",&u,&v);
addedge(u,v),addedge(v,u);
}
tot=,dfs1(,,),dfs2(,);
memset(rt,,sizeof rt),tot2=;
dfs3();
for(int last=; m--;) {
int u,v,k;
scanf("%d%d%d",&u,&v,&k),u^=last;
int w=lca(u,v);
int ans=b[ask(rt[u],rt[v],rt[w],rt[fa[w]],k)-];
printf("%d\n",ans),last=ans;
}
return ;
}
然后我又测试了倍增和RMQ求LCA的方法,发现居然还不如dfs序+可持久化线段树的方法快~~毕竟倍增和RMQ预处理的时间和空间复杂度都是$O(nlogn)$,而树剖只需要$O(n)$,而且查询速度也比较快。
倍增:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+,inf=0x3f3f3f3f;
int hd[N],ne,n,n2,m,fa[N][],dep[N],rt[N],ls[N*],rs[N*],val[N*],tot2,a[N],b[N];
struct E {int v,nxt;} e[N<<];
void addedge(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++;}
#define mid ((l+r)>>1)
void upd(int& u,int v,int x,int l=,int r=n2) {
if(!u)u=++tot2;
val[u]=val[v]+;
if(l==r)return;
if(x<=mid)upd(ls[u],ls[v],x,l,mid),rs[u]=rs[v];
else upd(rs[u],rs[v],x,mid+,r),ls[u]=ls[v];
}
void dfs(int u,int f,int d) {
fa[u][]=f,dep[u]=d,upd(rt[u],rt[fa[u][]],a[u]);
for(int i=; i<; ++i)fa[u][i]=fa[fa[u][i-]][i-];
for(int i=hd[u]; ~i; i=e[i].nxt) {
int v=e[i].v;
if(v==fa[u][])continue;
dfs(v,u,d+);
}
}
int lca(int u,int v) {
if(dep[u]<dep[v])swap(u,v);
for(int i=; dep[u]!=dep[v]; --i)if(dep[fa[u][i]]>=dep[v])u=fa[u][i];
if(u==v)return u;
for(int i=; i>=; --i)if(fa[u][i]!=fa[v][i])u=fa[u][i],v=fa[v][i];
return fa[u][];
}
int ask(int u,int v,int w1,int w2,int k,int l=,int r=n2) {
if(l==r)return l;
int cnt=val[ls[u]]+val[ls[v]]-val[ls[w1]]-val[ls[w2]];
return k<=cnt?ask(ls[u],ls[v],ls[w1],ls[w2],k,l,mid):ask(rs[u],rs[v],rs[w1],rs[w2],k-cnt,mid+,r);
}
int main() {
memset(hd,-,sizeof hd),ne=;
scanf("%d%d",&n,&m);
for(int i=; i<=n; ++i)scanf("%d",&a[i]);
for(int i=; i<=n; ++i)b[i-]=a[i];
sort(b,b+n),n2=unique(b,b+n)-b;
for(int i=; i<=n; ++i)a[i]=(lower_bound(b,b+n2,a[i])-b)+;
for(int i=; i<n; ++i) {
int u,v;
scanf("%d%d",&u,&v);
addedge(u,v),addedge(v,u);
}
memset(rt,,sizeof rt),tot2=;
dfs(,,);
for(int last=; m--;) {
int u,v,k;
scanf("%d%d%d",&u,&v,&k),u^=last;
int w=lca(u,v);
int ans=b[ask(rt[u],rt[v],rt[w],rt[fa[w][]],k)-];
printf("%d\n",ans),last=ans;
}
return ;
}
RMQ:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+,inf=0x3f3f3f3f;
int hd[N],ne,n,n2,m,fa[N],dep[N],pos[N],ST[N<<][],Log[N<<],tot,rt[N],ls[N*],rs[N*],val[N*],tot2,a[N],b[N];
struct E {int v,nxt;} e[N<<];
void addedge(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++;}
#define mid ((l+r)>>1)
void upd(int& u,int v,int x,int l=,int r=n2) {
if(!u)u=++tot2;
val[u]=val[v]+;
if(l==r)return;
if(x<=mid)upd(ls[u],ls[v],x,l,mid),rs[u]=rs[v];
else upd(rs[u],rs[v],x,mid+,r),ls[u]=ls[v];
}
void dfs(int u,int f,int d) {
fa[u]=f,dep[u]=d,ST[++tot][]=u,pos[u]=tot,upd(rt[u],rt[fa[u]],a[u]);
for(int i=hd[u]; ~i; i=e[i].nxt) {
int v=e[i].v;
if(v==fa[u])continue;
dfs(v,u,d+),ST[++tot][]=u;
}
}
bool cmp(int a,int b) {return dep[a]<dep[b];}
void initST() {
for(int j=; j<; ++j)
for(int i=; i+(<<j)-<=tot; ++i)
ST[i][j]=min(ST[i][j-],ST[i+(<<(j-))][j-],cmp);
}
int lca(int u,int v) {
int l=pos[u],r=pos[v];
if(l>r)swap(l,r);
int i=Log[r-l+];
return min(ST[l][i],ST[r-(<<i)+][i],cmp);
}
int ask(int u,int v,int w1,int w2,int k,int l=,int r=n2) {
if(l==r)return l;
int cnt=val[ls[u]]+val[ls[v]]-val[ls[w1]]-val[ls[w2]];
return k<=cnt?ask(ls[u],ls[v],ls[w1],ls[w2],k,l,mid):ask(rs[u],rs[v],rs[w1],rs[w2],k-cnt,mid+,r);
}
int main() {
for(int i=; i<(N<<); ++i)Log[i]=log2(i+0.5);
memset(hd,-,sizeof hd),ne=;
scanf("%d%d",&n,&m);
for(int i=; i<=n; ++i)scanf("%d",&a[i]);
for(int i=; i<=n; ++i)b[i-]=a[i];
sort(b,b+n),n2=unique(b,b+n)-b;
for(int i=; i<=n; ++i)a[i]=(lower_bound(b,b+n2,a[i])-b)+;
for(int i=; i<n; ++i) {
int u,v;
scanf("%d%d",&u,&v);
addedge(u,v),addedge(v,u);
}
memset(rt,,sizeof rt),tot2=;
dfs(,,),initST();
for(int last=; m--;) {
int u,v,k;
scanf("%d%d%d",&u,&v,&k),u^=last;
int w=lca(u,v);
int ans=b[ask(rt[u],rt[v],rt[w],rt[fa[w]],k)-];
printf("%d\n",ans),last=ans;
}
return ;
}