考虑第一问的部分分。显然设f[i]为i子树从根开始扩展的所需步数,考虑根节点的扩展顺序,显然应该按儿子子树所需步数从大到小进行扩展,将其排序即可。
要做到n=3e5,考虑换根dp。计算某点答案时先将其在父亲中的贡献去掉,然后用和之前同样的方法做即可。冷静一下也没什么复杂的。
第二问注意到两个点扩展出来的点集是不相交的,枚举一条断边,就可以做到n2logn。显然断边的位置可以二分。就是nlog2n了。
#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
#define ll long long
#define N 300010
char getc(){char c=getchar();while ((c<'A'||c>'Z')&&(c<'a'||c>'z')&&(c<'0'||c>'9')) c=getchar();return c;}
int gcd(int n,int m){return m==0?n:gcd(m,n%m);}
int read()
{
int x=0,f=1;char c=getchar();
while (c<'0'||c>'9') {if (c=='-') f=-1;c=getchar();}
while (c>='0'&&c<='9') x=(x<<1)+(x<<3)+(c^48),c=getchar();
return x*f;
}
int n,a,b,p[N],f[N],pos[N],fa[N],id[N],e[N],t,ans=N,root;
struct data2
{
int x,y;
bool operator <(const data2&a) const
{
return x<a.x;
}
}q[N];
vector<data2> Q[N];
vector<int> pre[N],suf[N];
bool flag[N<<1];
struct data{int to,nxt;
}edge[N<<1];
void addedge(int x,int y){t++;edge[t].to=y,edge[t].nxt=p[x],p[x]=t;}
void dfs(int k,int from)
{
for (int i=p[k];i;i=edge[i].nxt)
if (edge[i].to!=from&&!flag[i]) dfs(edge[i].to,k);
int cnt=0;
for (int i=p[k];i;i=edge[i].nxt)
if (edge[i].to!=from&&!flag[i]) q[++cnt].x=f[edge[i].to],q[cnt].y=edge[i].to;
sort(q+1,q+cnt+1);reverse(q+1,q+cnt+1);
if (k==1) for (int i=1;i<=cnt;i++) Q[1].push_back((data2){q[i].x+i,q[i].y}),pos[q[i].y]=i-1;
f[k]=0;for (int i=1;i<=cnt;i++) f[k]=max(f[k],q[i].x+i);
}
int calc(int root){dfs(root,root);return f[root];}
void getans(int k,int from)
{
ans=min(ans,f[k]);
int s=Q[k].size();
pre[k].push_back(Q[k][0].x);
for (int j=1;j<s;j++) pre[k].push_back(Q[k][j].x);
for (int j=1;j<s;j++) pre[k][j]=max(pre[k][j],pre[k][j-1]);
if (s)
{
suf[k].push_back(Q[k][s-1].x);
for (int j=s-2;j>=0;j--) suf[k].push_back(Q[k][j].x);
for (int j=1;j<s;j++) suf[k][j]=max(suf[k][j],suf[k][j-1]);
reverse(suf[k].begin(),suf[k].end());
}
for (int i=p[k];i;i=edge[i].nxt)
if (edge[i].to!=from)
{
int cnt=0,x=edge[i].to,tmp=f[k];
//f[k]去掉edge[i].to后的答案
f[k]=0;
if (pos[x]>0) f[k]=pre[k][pos[x]-1];
if (pos[x]+1<s) f[k]=max(f[k],suf[k][pos[x]+1]-1);
for (int j=p[x];j;j=edge[j].nxt)
q[++cnt].x=f[edge[j].to],q[cnt].y=edge[j].to;
sort(q+1,q+cnt+1);reverse(q+1,q+cnt+1);
f[x]=0;for (int j=1;j<=cnt;j++) f[x]=max(f[x],q[j].x+j),Q[x].push_back((data2){q[j].x+j,q[j].y}),pos[q[j].y]=j-1;
f[k]=tmp;
getans(x,k);
}
}
void getfa(int k)
{
for (int i=p[k];i;i=edge[i].nxt)
if (edge[i].to!=fa[k])
{
fa[edge[i].to]=k;
e[edge[i].to]=i+(i&1);
getfa(edge[i].to);
}
}
int main()
{
n=read(),a=read(),b=read();
for (int i=1;i<n;i++)
{
int x=read(),y=read();
addedge(x,y),addedge(y,x);
}
dfs(1,1);
getans(1,1);
cout<<ans<<endl;ans=N;
getfa(a);int cnt=0;
int x=b;while (x!=a) id[++cnt]=x,x=fa[x];
int l=1,r=cnt;
while (l<=r)
{
int mid=l+r>>1;int x=id[mid];
flag[e[x]]=1,flag[e[x]-1]=1;
int u=calc(a),v=calc(b);
flag[e[x]]=0,flag[e[x]-1]=0;
ans=min(ans,max(u,v));
if (u>=v) l=mid+1;else r=mid-1;
}
cout<<ans;
return 0;
}