这道题也算是厉害了,改了整整俩小时最后发现是深信的LCA打错了,悲伤啊!信仰崩塌了!
顺便复习LCA,给出模板
void init(){//p[i][j]表示i节点2^j的祖先 int j; for(j=0;(1<<j)<=n;j++) pos(i,1,n) p[i][j]=-1; pos(i,1,n) p[i][0]=fa[i]; for(j=1;(1<<j)<=n;j++) pos(i,1,n) if(p[i][j-1]!=-1) p[i][j]=p[p[i][j-1]][j-1];//i的2^j祖先即为2^(j-1)祖先的2^(j-1)祖先 } int lca(int a,int b){ int i; if(dep[a]<dep[b]) swap(a,b);//深度深的先爬 for(i=0;(1<<i)<=dep[a];i++); i--; for(int j=i;j>=0;j--) if(dep[a]-(1<<j)>=dep[b]) a=p[a][j];//爬到等高 if(a==b) return a;//如果在一条链上,直接返回 for(int j=i;j>=0;j--){ if(p[a][j]!=-1&&p[a][j]!=p[b][j]){ a=p[a][j];b=p[b][j];//两边一起爬,直到爬到LCA } } return fa[a]; }
这道题的题解是这样的:
#include<iostream> #include<algorithm> #include<cstdio> #include<cstring> #include<queue> #include<cmath> #define pos(i,a,b) for(int i=(a);i<=(b);i++) using namespace std; #define N 101000 int n,m; struct haha{ int next,to; }edge[N*2]; int head[N],cnt=1; void add(int u,int v){ edge[cnt].to=v;edge[cnt].next=head[u];head[u]=cnt++; } int dep[N],fa[N],size[N]; int p[N][22]; void dfs(int x){ size[x]=1; for(int i=head[x];i;i=edge[i].next){ int to=edge[i].to; if(fa[x]!=to){ fa[to]=x;dep[to]=dep[x]+1; dfs(to);size[x]+=size[to]; } } } void init(){//p[i][j]表示i节点2^j的祖先 int j; for(j=0;(1<<j)<=n;j++) pos(i,1,n) p[i][j]=-1; pos(i,1,n) p[i][0]=fa[i]; for(j=1;(1<<j)<=n;j++) pos(i,1,n) if(p[i][j-1]!=-1) p[i][j]=p[p[i][j-1]][j-1];//i的2^j祖先即为2^(j-1)祖先的2^(j-1)祖先 } int lca(int a,int b){ int i; if(dep[a]<dep[b]) swap(a,b);//深度深的先爬 for(i=0;(1<<i)<=dep[a];i++); i--; for(int j=i;j>=0;j--) if(dep[a]-(1<<j)>=dep[b]) a=p[a][j];//爬到等高 if(a==b) return a;//如果在一条链上,直接返回 for(int j=i;j>=0;j--){ if(p[a][j]!=-1&&p[a][j]!=p[b][j]){ a=p[a][j];b=p[b][j];//两边一起爬,直到爬到LCA } } return fa[a]; } int get(int x,int anc,int num){ int i; for(i=0;(1<<i)<=dep[x];i++); i--; for(int j=20;j>=0;j--) if(dep[x]-(1<<j)>=dep[anc]+num) x=p[x][j]; return x; } int main(){ //freopen("date.in","r",stdin); //freopen("date.out","w",stdout); scanf("%d",&n); pos(i,1,n-1){ int x,y;scanf("%d%d",&x,&y); add(x,y);add(y,x); } dfs(2);//cout<<fa[2]; init(); scanf("%d",&m);//system("pause"); pos(i,1,m){ int x,y;scanf("%d%d",&x,&y); int anc=lca(x,y); int dis=dep[x]+dep[y]-2*dep[anc];//cout<<dis<<endl; if(dis%2==1){ printf("0\n");continue; } if(x==y){ printf("%d\n",n);continue; } if(dep[x]==dep[y]){ int p1=get(x,anc,1),p2=get(y,anc,1); printf("%d\n",n-size[p1]-size[p2]); continue; } if(dep[x]<dep[y]) swap(x,y); int p1=get(x,x,-dis/2),p2=get(x,p1,1); printf("%d\n",size[p1]-size[p2]); } return 0; }