\(1\le n,m\le 2\times 10^5\)。
考虑一个简单直接的建图,若将 \(a\) 颜色作为树根需要将 \(b\) 颜色变为 \(a\),从 \(a\) 向 \(b\) 连边。
枚举每个颜色 \(a\) 的所有点,用类似虚树的方法找出所有 \(a\) 颜色依赖的颜色并连边。连边可以用倍增优化建图。倍增优化建图和倍增是一样的,看代码就懂了。
建图完成后跑 tarjan,找出那些没有入边的强连通分量更新答案即可。
#include <cstdio>
#include <vector>
inline int min(const int x, const int y) {return x < y ? x : y;}
struct Edge {int to, nxt;} e[30000005];
int head[5000005], col[5000005], fa[200005][20], f[200005][20], dep[200005], ndtot, tot;
int dfn[5000005], low[5000005], s[5000005], belong[5000005], top, cnt, scc, n, m, ans = 1e9;
bool Instack[5000005];
inline void AddEdge(int u, int v) {e[++ tot].to = v, e[tot].nxt = head[u], head[u] = tot;}
std::vector<int> G[200005], vec[200005], tmp;
void dfs(int u) {
if (u != 1) {
dep[u] = dep[fa[u][0]] + 1, AddEdge(f[u][0] = ++ ndtot, col[fa[u][0]]);
for (int i = 1; i <= 18; ++ i) {
if (!(fa[u][i] = fa[fa[u][i - 1]][i - 1])) continue;
AddEdge(f[u][i] = ++ ndtot, f[u][i - 1]);
AddEdge(f[u][i], f[fa[u][i - 1]][i - 1]);
}
}
for (int v : G[u]) if (v != fa[u][0]) fa[v][0] = u, dfs(v);
}
int LCA(int u, int v) {
if (dep[u] < dep[v]) u ^= v ^= u ^= v;
int t = dep[u] - dep[v], c = col[u];
for (int i = 0; i <= 18; ++ i)
if (t & 1 << i) AddEdge(c, f[u][i]), u = fa[u][i];
if (u == v) return u;
for (int i = 18; i >= 0; -- i)
if (fa[u][i] != fa[v][i]) AddEdge(c, f[u][i]), AddEdge(c, f[v][i]), u = fa[u][i], v = fa[v][i];
AddEdge(c, f[u][0]), AddEdge(c, f[v][0]);
return fa[u][0];
}
void Tarjan(int u) {
dfn[u] = low[u] = ++ cnt, s[++ top] = u, Instack[u] = true;
for (int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if (!dfn[v]) Tarjan(v), low[u] = min(low[u], low[v]);
else if (Instack[v]) low[u] = min(low[u], dfn[v]);
}
if (dfn[u] == low[u]) {
tmp.clear();
int sum = 0;
++ scc;
do {
Instack[s[top]] = false, sum += (s[top] <= m);
belong[s[top]] = scc, tmp.push_back(s[top]);
} while (s[top --] != u);
bool flag = true;
for (int v : tmp) {
for (int i = head[v]; i; i = e[i].nxt)
if (belong[e[i].to] != scc) {flag = false; break;}
if (!flag) break;
}
if (flag) ans = min(ans, sum - 1);
}
}
int main() {
scanf("%d%d", &n, &m);
ndtot = m;
for (int i = 1, u, v; i < n; ++ i) scanf("%d%d", &u, &v), G[u].push_back(v), G[v].push_back(u);
for (int i = 1; i <= n; ++ i) scanf("%d", col + i), vec[col[i]].push_back(i);
dfs(1);
for (int i = 1; i <= m; ++ i) {
int lca = vec[i][0];
for (int j = 1; j < vec[i].size(); ++ j) lca = LCA(lca, vec[i][j]);
}
for (int i = 1; i <= m; ++ i) if (!dfn[i]) Tarjan(i);
printf("%d", ans);
return 0;
}