首先第一问的树形换根dp是很显然的。
首先一次dp算出一个点子树内的答案,然后再一次换根把儿子什么的排个序就好了。
考虑第二个怎么做。
我们考虑\(a\)到\(b\)之间的路径,这中间肯定有一条边是不被走到的,然后感性理解一下这个东西具有可二分性。
就是大概要找到一个两边平均的位置。
然后就很好做了。时间复杂度\(O(nlogn)\)
code:
#include<bits/stdc++.h>
#define I inline
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
#define abs(x) ((x)>0?(x):-(x))
#define re register
#define ll long long
#define db double
#define N 300000
#define M 50000
#define mod 1000000000
#define mod2 39989
#define eps (1e-7)
#define U unsigned int
#define it iterator
#define Gc() getchar()
#define Me(x,y) memset(x,y,sizeof(x))
using namespace std;
int n,m,k,F[N+5],dp[N+5],x,y,Bh,G[N+5],Q1[N+5],Q2[N+5],ans=1e9,a,b,l,r,mid,fl[N+5<<1];
struct yyy{int to,z;};
struct ljb{
int head,h[N+5];yyy f[N+5<<1];
I void add(int x,int y){f[head]=(yyy){y,h[x]};h[x]=head++;}
}s;
struct ques{int w,id;}B[N+5];I bool cmp(ques x,ques y){return x.w>y.w;}
I void dfs1(int x,int last){
dp[x]=0;yyy tmp;int i;for(i=s.h[x];~i;i=tmp.z) tmp=s.f[i],!fl[i]&&tmp.to^last&&(dfs1(tmp.to,x),0);
Bh=0;for(i=s.h[x];~i;i=tmp.z) tmp=s.f[i],tmp.to^last&&!fl[i]&&(B[++Bh]=(ques){dp[tmp.to],tmp.to},0);
sort(B+1,B+Bh+1,cmp);for(i=1;i<=Bh;i++) dp[x]=max(dp[x],B[i].w+i);
}
I void dfs2(int x,int last){
yyy tmp;int i;Bh=0;for(i=s.h[x];~i;i=tmp.z) tmp=s.f[i],tmp.to^last&&(B[++Bh]=(ques){dp[tmp.to],tmp.to},0);x^1&&(B[++Bh]=(ques){F[x]-1,x},0);
sort(B+1,B+Bh+1,cmp);for(i=1;i<=Bh;i++) B[i].w+=i,G[x]=max(G[x],B[i].w);ans=min(ans,G[x]);for(i=1;i<=Bh;i++) Q1[i]=max(Q1[i-1],B[i].w);
Q2[Bh+1]=0;for(i=Bh;i;i--) Q2[i]=max(Q2[i+1],B[i].w);for(i=1;i<=Bh;i++) B[i].id^x&&(F[B[i].id]=max(Q1[i-1],Q2[i+1]-1)+1,dfs2(B[i].id,x),0);
}
struct Solve{
int st[N+5<<1],sh,flag;
I void Make(int x,int last){
if(x==b||flag)return (void)(flag=1);yyy tmp;for(int i=s.h[x];~i&&!flag;i=tmp.z) tmp=s.f[i],tmp.to^last&&(st[++sh]=i,Make(tmp.to,x),sh-=flag^1);
}
I int check(int x){
fl[st[x]]=fl[st[x]^1]=1;dfs1(a,0);dfs1(b,0);fl[st[x]]=fl[st[x]^1]=0;return dp[a]>dp[b];
}
I void solve(){
//Make(a,0);int ans=1e18;for(int i=1;i<=sh;i++) ans=min(ans,check(i)),printf("%d\n",check(i));
Make(a,0);l=0;r=sh;while(l+1<r) mid=l+r>>1,(check(mid)?r:l)=mid;check(l);ans=min(ans,max(dp[a],dp[b])); check(r);ans=min(ans,max(dp[a],dp[b])); printf("%d\n",ans);
}
}S1;
int main(){
// freopen("game.in","r",stdin);freopen("game.out","w",stdout);
re int i;Me(s.h,-1);scanf("%d%d%d",&n,&a,&b);for(i=1;i<n;i++) scanf("%d%d",&x,&y),s.add(x,y),s.add(y,x);dfs1(1,0);dfs2(1,0);printf("%d\n",ans);S1.solve();
}