问题描述
在 JOI 的国度有 N 个小镇,从 1 到 N 编号,并由 N−1 条双向道路连接。第 i 条道路连接了 Ai 和 Bi 这两个编号的小镇。
这个国家的国王现将整个国家分为 K 个城市,从 1 到 K 编号,每个城市都有附属的小镇,其中编号为 j 的小镇属于编号为 Cj 的城市。每个城市至少有一个附属小镇。
国王还要选定一个首都。首都的条件是该城市的任意小镇都只能通过属于该城市的小镇到达。
但是现在可能不存在这样的选址,所以国王还需要将一些城市进行合并。对于合并城市 x 和 y ,指的是将所有属于 y 的小镇划归给 x 城。
你需要求出最少的合并次数。
输入格式
输入第一行两个整数 N,K,为小镇和城市的数量。
接下来的 N−1 行,每行两个整数 Ai,Bi,描述了 N−1 条道路。
再接下来的 N 行,每行一个整数 Cj,表示编号为 j 的小镇属于编号为 Cj 的城市。
输出格式
输出一行一个整数为最少的合并次数。
样例输入
6 3
2 1
3 5
6 2
3 4
2 3
1
3
1
2
3
2
样例输出
1
解析
我们先单独考虑某一种颜色。如果在首都中要包括这种颜色的城镇,那么两个该种颜色的点之间的所有点都必须合并到一起。将其转化为图上的关系,设颜色 i 向颜色 j 连边表示将颜色 i 的城镇合并到颜色 j 中。因此,我们不妨将每一种颜色的虚树建出来,然后就可以方便地在原树上寻找相同颜色之间的点了。
然而,这样连边是 \(n^2\) 的。考虑到每次连边的对象都是树上的一段路径,我们可以用树链剖分优化连边。连边的方式与线段树优化连边类似。为了实现从颜色向颜色连边,我们让每一种颜色向线段树的叶子节点(实际是原树上的节点)连边。对于每一棵虚树,我们用树链剖分找到虚树上两点之间对应的路径,然后从路径对应的区间向虚树的颜色连边即可。
这样,我们得到了一张有向图。对于构成环的合并关系,我们可以将其直接合并,相当于缩点,将原图转化为一个DAG。考虑入度为0的点,这样的点一定是不需要额外合并的,可以作为首都。因此,答案就是DAG上最小的入度为0的点。
代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
#define N 200002
using namespace std;
vector<int> v[N],c[N],a;
int head[N*8],ver[N*16],nxt[N*16],l;
int n,k,i,j,col[N],dep[N],son[N],size[N],fa[N],top[N],pos[N],in[N],out[N],cnt,tot;
int dfn[N*8],low[N*8],s[N*8],T,tim,sccno[N*8],num[N*8],deg[N*8];
int read()
{
char c=getchar();
int w=0;
while(c<'0'||c>'9') c=getchar();
while(c<='9'&&c>='0'){
w=w*10+c-'0';
c=getchar();
}
return w;
}
void insert(int x,int y)
{
l++;
ver[l]=y;
nxt[l]=head[x];
head[x]=l;
}
void dfs1(int x,int pre)
{
fa[x]=pre;
dep[x]=dep[pre]+1;
size[x]=1;
for(int i=head[x];i;i=nxt[i]){
int y=ver[i];
if(y!=pre){
dfs1(y,x);
size[x]+=size[y];
if(size[y]>size[son[x]]) son[x]=y;
}
}
}
void dfs2(int x,int t)
{
top[x]=t;
in[x]=++cnt;
pos[cnt]=x;
if(son[x]) dfs2(son[x],t);
for(int i=head[x];i;i=nxt[i]){
int y=ver[i];
if(y!=fa[x]&&y!=son[x]) dfs2(y,y);
}
}
void build(int p,int l,int r)
{
tot++;
if(l==r){
insert(p+k,col[pos[l]]);
return;
}
int mid=(l+r)/2;
insert(p+k,p*2+k);insert(p+k,p*2+1+k);
build(p*2,l,mid);build(p*2+1,mid+1,r);
}
void link(int p,int l,int r,int ql,int qr,int c)
{
if(ql<=l&&r<=qr){
insert(c,p+k);
return;
}
int mid=(l+r)/2;
if(ql<=mid) link(p*2,l,mid,ql,qr,c);
if(qr>mid) link(p*2+1,mid+1,r,ql,qr,c);
}
int LCA(int u,int v)
{
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]]) swap(u,v);
u=fa[top[u]];
}
if(dep[u]>dep[v]) swap(u,v);
return u;
}
void split(int c,int u,int v)
{
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]]) swap(u,v);
link(1,1,n,in[top[u]],in[u],c);
u=fa[top[u]];
}
if(dep[u]>dep[v]) swap(u,v);
link(1,1,n,in[u],in[v],c);
}
int my_comp(const int &x,const int &y)
{
return in[x]<in[y];
}
void dfs(int c,int x)
{
for(int i=0;i<v[x].size();i++){
split(c,x,v[x][i]);
dfs(c,v[x][i]);
}
}
void Tarjan(int x)
{
dfn[x]=low[x]=++tim;
s[++T]=x;
for(int i=head[x];i;i=nxt[i]){
int y=ver[i];
if(!dfn[y]){
Tarjan(y);
low[x]=min(low[x],low[y]);
}
else if(!sccno[y]) low[x]=min(low[x],dfn[y]);
}
if(dfn[x]==low[x]){
cnt++;
while(1){
int y=s[T--];
sccno[y]=cnt;
if(y<=k) num[cnt]++;
if(y==x) break;
}
}
}
int main()
{
n=read();k=read();
tot=k;
for(i=1;i<n;i++){
int u=read(),v=read();
insert(u,v);
insert(v,u);
}
for(i=1;i<=n;i++){
col[i]=read();
c[col[i]].push_back(i);
}
dfs1(1,0);dfs2(1,1);
memset(head,0,sizeof(head));l=0;
build(1,1,n);
for(i=1;i<=k;i++){
if(!c[i].size()) continue;
a.clear();
sort(c[i].begin(),c[i].end(),my_comp);
s[1]=T=1;
a.push_back(1);
for(j=0;j<c[i].size();j++){
if(c[i][j]==1) continue;
int x=c[i][j],lca=LCA(x,s[T]);
if(lca==s[T]){
s[++T]=x;
a.push_back(x);
continue;
}
while(T>1&&in[s[T-1]]>=in[lca]) v[s[T-1]].push_back(s[T]),T--;
if(lca!=s[T]) v[lca].push_back(s[T]),s[T]=lca,a.push_back(lca);
s[++T]=x;
a.push_back(x);
}
while(T>1) v[s[T-1]].push_back(s[T]),T--;
sort(a.begin(),a.end(),my_comp);
if(v[1].size()==1&&col[1]!=i) dfs(i,a[1]);
else dfs(i,1);
for(j=0;j<a.size();j++) v[a[j]].clear();
}
T=cnt=0;
for(i=1;i<=tot;i++){
if(!dfn[i]) Tarjan(i);
}
for(i=1;i<=tot;i++){
for(j=head[i];j;j=nxt[j]){
if(sccno[i]!=sccno[ver[j]]) deg[sccno[i]]++;
}
}
int ans=1<<30;
for(i=1;i<=cnt;i++){
if(deg[i]==0&&num[i]!=0) ans=min(ans,num[i]-1);
}
printf("%d\n",ans);
return 0;
}