题目链接:https://vjudge.net/problem/HDU-5977
题意:给一颗树,每个结点上有一个权值a[i],a[i]<=10,求有多少条路径满足这条路径上所有权值的结点都出现了。
思路:
首先利用二进制的思想,将a[i]转化为1<<(a[i]-1)。我们在子树中,计算出结点到重心的路径,用二进制表示,比如011表示该路径中权值3没有出现、权值1和2出现。因为k最大为10,那么我们在计算结果时把所有可能枚举一遍,也就1024,如果枚举的i和当前路径取或后=(1<<k)-1,那么该路径满足要求,加上即可。具体实现时用桶记录信息,mine[i]表示权值为i的路径的个数。
另外,题目规定不同的路径仅当起点终点均不同,所以(1,2)和(2,1)是两个合法解,我的处理是先getdis一遍得到桶的信息,处理子结点,现将该子结点的子树的信息清除掉,即change函数,dfs之后再恢复回来。这种处理很重要,所以把这题当作模板记录一下。
不得不说,写点分治时要非常细心,我总是半小时代码,1小时改bug,老是一些简单错误。
AC代码:
#include<cstdio> #include<algorithm> #include<cstring> using namespace std; typedef long long LL; const int maxn=50005; const int inf=0x3f3f3f3f; struct node{ int v,nex; }edge[maxn<<1]; int n,k,cnt,head[maxn],a[maxn],sz[maxn],mson[maxn],Min,size,root; int vis[maxn],flag; LL ans,mine[1050]; void adde(int u,int v){ edge[++cnt].v=v; edge[cnt].nex=head[u]; head[u]=cnt; } void getroot(int u,int fa){ sz[u]=1,mson[u]=0; for(int i=head[u];i;i=edge[i].nex){ int v=edge[i].v; if(vis[v]||v==fa) continue; getroot(v,u); sz[u]+=sz[v]; mson[u]=max(mson[u],sz[v]); } mson[u]=max(mson[u],size-sz[u]); if(mson[u]<Min) Min=mson[u],root=u; } void getdis(int u,int fa,int len){ ++mine[len]; for(int i=head[u];i;i=edge[i].nex){ int v=edge[i].v; if(vis[v]||v==fa) continue; getdis(v,u,len|a[v]); } } void change(int u,int fa,int len,int f){ mine[len]+=f; for(int i=head[u];i;i=edge[i].nex){ int v=edge[i].v; if(vis[v]||v==fa) continue; change(v,u,len|a[v],f); } } void dfs(int u,int fa,int len){ for(int i=0;i<(1<<k);++i){ if((i|len)!=flag) continue; if(!mine[i]) continue; ans+=mine[i]; } for(int i=head[u];i;i=edge[i].nex){ int v=edge[i].v; if(vis[v]||v==fa) continue; dfs(v,u,len|a[v]); } } void solve(int u){ getdis(u,0,a[u]); for(int i=0;i<(1<<k);++i){ if((i|a[u])!=flag) continue; if(!mine[i]) continue; ans+=mine[i]; } for(int i=head[u];i;i=edge[i].nex){ int v=edge[i].v; if(vis[v]) continue; change(v,u,a[u]|a[v],-1); dfs(v,u,a[u]|a[v]); change(v,u,a[u]|a[v],1); } memset(mine,0,sizeof(mine)); } void fenzhi(int u,int ssize){ vis[u]=1; solve(u); for(int i=head[u];i;i=edge[i].nex){ int v=edge[i].v; if(vis[v]) continue; Min=inf,root=0; size=sz[v]<sz[u]?sz[v]:ssize-sz[u]; getroot(v,0); fenzhi(root,size); } } int main(){ while(~scanf("%d%d",&n,&k)){ cnt=0; ans=0; flag=(1<<k)-1; for(int i=0;i<=n;++i) head[i]=vis[i]=0; memset(mine,0,sizeof(mine)); for(int i=1;i<=n;++i){ scanf("%d",&a[i]); a[i]=1<<(a[i]-1); } for(int i=1;i<n;++i){ int u,v; scanf("%d%d",&u,&v); adde(u,v); adde(v,u); } Min=inf,root=0,size=n; getroot(1,0); fenzhi(root,n); printf("%lld\n",ans); } return 0; }