学了好久(一两个星期)都没彻底搞懂的lca,今天总算理解了。就来和大家分享下我自己的心得
首先,如果你还不懂什么是lca,出门左转自行百度
首先讲倍增
倍增的思想很简单,首先进行预处理,用一个深搜将每个点的深度和它向上跳一步到达的点(也就是它的父节点)处理出来,然后用下列递推式
f[i][j]=f[f[i][j-1]][j-1]
求出该点跳2^j步所到达的点。这里解释一下,为什么是f[f[i][]j-1][j-1]?因为倍增每次都是跳的2的整数次幂步,而2^j=2^(j-1)+2^(j-1);这样就不难理解了。
然后,对于每两个询问的点,只需要先找出那个点的深度更深,就将它跳跃到与另一个点深度相同,如果此时两个点相同,那么这个点就是最近公共祖先;如果不相同,两个点就一起跳,直找到最近公共祖先为止。
上代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<cstdlib>
#include<algorithm>
#define N 500005
using namespace std;
int n,m,s,d[N],f[N][],head[N];
struct Edge{
int from,to,next;
}edge[N*];
inline int read()
{
char ch=getchar();int num=;
if(ch<''||ch>'')ch=getchar();
while(ch>=''&&ch<='')
{num=num*+ch-'';
ch=getchar();}
return num;
}
int anum=;
void add(int x,int y)
{edge[anum].to=y;
edge[anum].next=head[x];
head[x]=anum++;}
void dfs(int u)
{
for(int i=head[u];i!=-;i=edge[i].next)
{
int ne=edge[i].to;
if(d[ne]==)
{d[ne]=d[u]+;
f[ne][]=u;
dfs(ne);}
}
}
void init()
{
for(int i=;i<=;i++)
{
for(int j=;j<=n;j++)
{f[j][i]=f[f[j][i-]][i-];}
}
}
int lca(int a,int b)
{
if(d[a]<d[b]) swap(a,b);
for(int i=;i>=;i--)
{if(d[f[a][i]]>=d[b])
a=f[a][i];}
if(a==b) return a;
for(int i=;i>=;i--)
if(f[a][i]!=f[b][i])
a=f[a][i],b=f[b][i];
return f[a][];
}
int main()
{
memset(head,-,sizeof(head));
n=read();m=read();s=read();
for(int i=;i<n;i++)
{int x,y;
x=read();y=read();
add(x,y);add(y,x);}
d[s]=;
dfs(s);init();
for(int i=;i<=m;i++)
{int a,b;
a=read();b=read();
printf("%d\n",lca(a,b));}
return ;
}
关于tarjan,具体思想我在另外一篇博客中已经讲过了,这里就只放代码,思路请转:这里
下面是代码:
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<cstdlib>
#define N 500005
#define M 1000001
using namespace std;
int n,m,s,cnt1,cnt2;
int dad[N],ans[N];
bool used[N];
struct edge{
int v,num,next;
}e1[M],e2[M];
struct road{
int head;
}v1[N],v2[N];
void before()
{
memset(v1,-,sizeof(v1));
memset(v2,-,sizeof(v2));
memset(used,false,sizeof(used));
memset(dad,-,sizeof(dad));
}
int find(int x)
{
return dad[x]==-?x:dad[x]=find(dad[x]);
}
void together(int x,int y)
{
int f1=find(x);
int f2=find(y);
if(f1!=f2)
dad[y]=x;
}
void v1add(int x,int y)
{
e1[cnt1].v=y;
e1[cnt1].next=v1[x].head;
v1[x].head=cnt1++;
}
void v2add(int x,int y,int z)
{
e2[cnt2].v=y;
e2[cnt2].num=z;
e2[cnt2].next=v2[x].head;
v2[x].head=cnt2++;
}
void tarjan(int u)
{
used[u]=true;
for(int i=v1[u].head;i!=-;i=e1[i].next)
{
int v=e1[i].v;
if(used[v]) continue;
tarjan(v);
together(u,v);
}
int sum;
for(int i=v2[u].head;i!=-;i=e2[i].next)
{
int v=e2[i].v;
sum=e2[i].num;
if(used[v])
ans[sum]=find(v);
}
}
int main()
{
int u,v;
scanf("%d%d%d",&n,&m,&s);
before();
int nn=n;
nn--;
while(nn--)
{
scanf("%d%d",&u,&v);
v1add(v,u);v1add(u,v);
}
for(int i=;i<=m;i++)
{
scanf("%d%d",&u,&v);
v2add(u,v,i);v2add(v,u,i);
}
tarjan(s);
for(int i=;i<=m;i++)
printf("%d\n",ans[i]);
return ;
}