题目描述
JYY有两棵树 A 和 B :树 A 有 N 个点,编号为 1 到 N ;树 B 有N+1 个节点,编号为 1 到N+1
JYY 知道树 B 恰好是由树 A 加上一个叶节点,然后将节点的编号打乱后得到的。他想知道,这个多余的叶子到底是树 B 中的哪一个叶节点呢?
输入格式
输入一行包含一个正整数 N。 接下来 N−1 行,描述树A,每行包含两个整数表示树 A 中的一条边; 接下来 N 行,描述树 B,每行包含两个整数表示树 B 中的一条边。
输出格式
输出一行一个整数,表示树 B 中相比树 A 多余的那个叶子的编号。如果有多个符合要求的叶子,输出 B 中编号最小的那一个的编号。
输入输出样例
输入
5
1 2
2 3
1 4
1 5
1 2
2 3
3 4
4 5
3 6
输出
1
说明/提示
对于所有数据,\(1 \leq n \leq 10 ^ 5\)
求出A树中所有点为根的树哈希值,并排序!
求出B树中所有点为根的树哈希值,从小到大枚举每一个叶子,如果它连接的点的哈希值减去这个叶子对它的贡献(为2)可以在A树所有哈希值中找到则输出,并结束程序!
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+10;
struct edge
{
int u,v,nxt;
}e[maxn<<1],ee[maxn<<1];
int head[maxn],js,headd[maxn],jss;
void addage(int u,int v)
{
e[++js].u=u;e[js].v=v;
e[js].nxt=head[u];head[u]=js;
}
void addagee(int u,int v)
{
ee[++jss].u=u;ee[jss].v=v;
ee[jss].nxt=headd[u];headd[u]=jss;
}
int n;
bool noprime[(int)2e6+10];
int prime[maxn],cnt;
void getprime()
{
for(int i=2;i<2e6+10;++i)
{
if(cnt>n)return;
if(!noprime[i])prime[++cnt]=i;
for(int j=1;j<=cnt;++j)
{
if(i*prime[j]>2e6)break;
noprime[i*prime[j]]=1;
if(i%prime[j]==0)break;
}
}
}
int f[maxn],siz[maxn],g[maxn];
void dfs(int u,int fa)
{
f[u]=siz[u]=1;
for(int i=head[u];i;i=e[i].nxt)
{
int v=e[i].v;
if(v==fa)continue;
dfs(v,u);
siz[u]+=siz[v];
f[u]+=f[v]*prime[siz[v]];
}
}
void getg(int u,int fa,int faf)
{
for(int i=head[u];i;i=e[i].nxt)
{
int v=e[i].v;
if(v==fa)continue;
int tp=f[u]-f[v]*prime[siz[v]]+faf*prime[n-siz[u]];
g[v]=f[v]+tp*prime[n-siz[v]];
getg(v,u,tp);
}
}
int ff[maxn],sizz[maxn],gg[maxn];
void dfss(int u,int fa)
{
ff[u]=sizz[u]=1;
for(int i=headd[u];i;i=ee[i].nxt)
{
int v=ee[i].v;
if(v==fa)continue;
dfss(v,u);
sizz[u]+=sizz[v];
ff[u]+=ff[v]*prime[sizz[v]];
}
}
void getgg(int u,int fa,int faf)
{
for(int i=headd[u];i;i=ee[i].nxt)
{
int v=ee[i].v;
if(v==fa)continue;
int tp=ff[u]-ff[v]*prime[sizz[v]]+faf*prime[n+1-sizz[u]];
gg[v]=ff[v]+tp*prime[n+1-sizz[v]];
getgg(v,u,tp);
}
}
int du[maxn],mnd[maxn];
int main()
{
scanf("%d",&n);
getprime();
for(int u,v,i=1;i<n;++i)
{
scanf("%d%d",&u,&v);
addage(u,v);addage(v,u);
}
for(int u,v,i=1;i<=n;++i)
{
scanf("%d%d",&u,&v);
addagee(u,v);addagee(v,u);
++du[u];++du[v];
}
dfs(1,0);
g[1]=f[1];
getg(1,0,0);
sort(g+1,g+1+n);
dfss(1,0);
gg[1]=ff[1];
getgg(1,0,0);
for(int u=1;u<n+1;++u)
{
if(du[u]!=1)continue;
int i=headd[u];
int w=gg[ee[i].v]-2;
int tp=lower_bound(g+1,g+1+n,w)-g;
if(tp>n)continue;
if(g[tp]==w)
{
printf("%d",u);
return 0;
}
}
return 0;
}