Code:
// Tree DP over a tree of m nodes whose nodes 1..n carry a fixed colour
// col[i] in {0, 1}. f[u][c] is the minimum number of coloured nodes inside the
// subtree of u when u itself is coloured c. The tree is rooted at node n + 1
// and the program prints min(f[n+1][0], f[n+1][1]).
// NOTE(review): this is the CQOI2009 "leaf colouring" recurrence; the problem
// statement itself is not part of this file — confirm constraints against it.
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

// Redirect stdin to "<s>.in" for local testing (see the commented call in main).
#define setIO(s) freopen(s ".in", "r", stdin)

using namespace std;

const int maxn = 1000000;   // array capacity: node ids must stay below maxn
const int inf = 1000000000; // cost marking a colour a fixed node may not take

int n, m, edges;
int col[maxn], f[maxn][2], hd[maxn], to[maxn << 1], nex[maxn << 1];

// Prepend the directed edge u -> v to u's adjacency list (edge ids start at 1,
// so 0 in hd[]/nex[] terminates a list).
void addedge(int u, int v) {
    nex[++edges] = hd[u];
    hd[u] = edges;
    to[edges] = v;
}

// Complete f[][] for every node in the subtree of u, entered from parent ff.
// Nodes 1..n are not descended into; their f values are fixed by main.
// Implemented iteratively so a deep (path-shaped) tree cannot overflow the
// call stack, which the recursive form could with up to maxn nodes.
void dfs(int u, int ff) {
    // (node, parent) pairs in pre-order: a node is recorded only after its
    // parent was popped, so every node precedes all of its descendants and a
    // reverse sweep finishes each child before adding it into its parent.
    vector<pair<int, int> > order;
    vector<pair<int, int> > stk(1, make_pair(u, ff));
    while (!stk.empty()) {
        const pair<int, int> cur = stk.back();
        stk.pop_back();
        order.push_back(cur);
        if (cur.first <= n) continue; // fixed-colour node: do not descend
        for (int i = hd[cur.first]; i; i = nex[i])
            if (to[i] != cur.second) stk.push_back(make_pair(to[i], cur.first));
    }
    // order[0] is u itself; its parent lies outside this subtree, so skip it.
    for (size_t k = order.size(); k-- > 1;) {
        const int v = order[k].first;
        const int p = order[k].second;
        // A child coloured the same as its parent need not be coloured
        // itself, saving one; otherwise take the child's other-colour cost.
        f[p][0] += min(f[v][0] - 1, f[v][1]);
        f[p][1] += min(f[v][1] - 1, f[v][0]);
    }
}

int main() {
    // setIO("input");
    if (scanf("%d%d", &m, &n) != 2) return 1;
    // Node ids up to max(m, n + 1) index col/f/hd, so both must fit below maxn.
    if (m < 0 || n < 0 || m >= maxn || n + 1 >= maxn) return 1;
    for (int i = 1; i <= m; ++i) f[i][0] = f[i][1] = 1; // colouring i costs 1
    for (int i = 1; i <= n; ++i) {
        if (scanf("%d", &col[i]) != 1) return 1;
        if (col[i] != 0 && col[i] != 1) return 1; // col[i] ^ 1 must index f[i]
        f[i][col[i] ^ 1] = inf; // node i may only take its given colour
    }
    for (int i = 1; i < m; ++i) {
        int a, b;
        if (scanf("%d%d", &a, &b) != 2) return 1;
        if (a < 1 || a > m || b < 1 || b > m) return 1;
        addedge(a, b);
        addedge(b, a);
    }
    dfs(n + 1, 0);
    printf("%d\n", min(f[n + 1][0], f[n + 1][1]));
    return 0;
}