一棵树上有K个黑色节点,剩余节点都为白色,将其划分成K个子树,使得每棵树上都只有1个黑色节点,共有多少种划分方案。
个人感觉这题比较难。假设dp(i,0..1)代表的是以i为根节点的子树种有0..1个黑色节点的划分方案数。
当节点i为白色时,对于它的每个孩子的节点处理:
求dp(i, 0)时有:
1,将该节点与孩子节点相连,但要保证孩子节点所在的子树种没有黑色节点;
2,将该节点不与该孩子节点相连,则该孩子节点要保证所在子树种有黑色节点;
即dp(i, 0) = π(dp(j,0 ) + dp(j, 1)) ,其中j为i的孩子节点
求dp(i,1)时有:
将该节点与其中每个孩子节点中的一个相连,并且保证该孩子节点所在子树中有1个黑色节点(所以共有K种情况,K为该节点的孩子数),并且对于剩下的节点可以选择连也可以选择不连,如果连接,则保证该子节点所在子树中没有黑色,如果不连,则要保证有黑色,所以对于剩下的每个
子节点的处理方案书有dp(j,0) + dp(j,1)个,然后将每个孩子处理的方案书相乘即可,最后将所有的方案相加即可。
当节点i为黑色的时候,求dp(i, 0) 肯定是0;
求dp(i, 1)时对于i的每个子节点也是有两种选择,连或者不连,如果连接,则保证该子节点所在子树中没有黑色,如果不连,则要保证有黑色,即对于每个子节点的处理数共有
dp(j, 0) + dp(j, 1)个,然后将每个孩子处理的方案数相乘。
最终dp(0,1)即为答案,这里假设0节点为根节点。
过程中可以加个小小的优化,当一个子节点所在的整棵子树中若没有黑色节点,那么该节点肯定与其父节点相连,所以计算时可以不考虑该节点。
#include <stdlib.h> #include <stdio.h> #include <algorithm> #include <vector> using namespace std; //int values[500001]; //long long sums[500001]; #define MODVALUE 1000000007 #define MOD(x) if((x) > MODVALUE) x %= MODVALUE; struct Edge { int to; int i; int totalcolor; Edge() { totalcolor = 0; } }; int compp(const void* a1, const void* a2) { return *((int*)a2) - *((int*)a1); } vector<Edge> G[100001]; int Color[100001]; long long res[100001][2]; //int TMP[100001]; bool Visited[100001]; void AddEdge(int from, int to) { Edge edge; edge.to = to; edge.i = G[to].size(); G[from].push_back(edge); edge.to = from; edge.i = G[from].size() - 1; G[to].push_back(edge); } int CountColor(int node) { Visited[node] = true; int count = 0; if (Color[node]) { count = 1; } for (int i = 0; i < G[node].size();i++) { Edge& edge = G[node][i]; if (!Visited[edge.to]) { edge.totalcolor = CountColor(edge.to); count += edge.totalcolor; } } return count; } void GetAns(int node) { Visited[node] = true; long long ans = 1; int countofcolor = 0; vector<int> TMP; for (int i = 0; i < G[node].size(); i++) { Edge& edge = G[node][i]; if (Visited[edge.to]) { continue; } //TMP[countofcolor++] = i; GetAns(edge.to); if (edge.totalcolor) { TMP.push_back(i); countofcolor++; //TMP[countofcolor++] = i; } } res[node][0] = 0; res[node][1] = 0; long long tmp1 = 1; long long tmp0 = 1; if (!Color[node]) { tmp1 = 0; } for (int i = 0; i < countofcolor; i++) { if (Color[node]) { Edge& edge = G[node][TMP[i]]; tmp1 *= (res[edge.to][1] + res[edge.to][0]); MOD(tmp1); tmp0 = 0; } else { Edge& edge1 = G[node][TMP[i]]; tmp0 *= (res[edge1.to][1] + res[edge1.to][0]); MOD(tmp0); long long tmp3 = 1; for (int j = 0; j < countofcolor; j++) { Edge& edge = G[node][TMP[j]]; if (i == j) { tmp3 *= res[edge.to][1]; MOD(tmp3); } else { tmp3 *= (res[edge.to][1] + res[edge.to][0]); MOD(tmp3); } } tmp1 += tmp3; } if (i == countofcolor - 1) { res[node][0] += tmp0; res[node][1] += tmp1; MOD(res[node][0]); MOD(res[node][1]); } } if (countofcolor == 0) { res[node][0] = Color[node] ? 0 : 1; res[node][1] = Color[node] ? 1 : 0; } } int main() { #ifdef _DEBUG freopen("e:\\in.txt", "r", stdin); #endif // _DEBUG int n; scanf("%d", &n); for (int i = 0; i < n - 1; i++) { int value; scanf("%d", &value); AddEdge(i + 1, value); } for (int i = 0; i < n; i++) { int value; scanf("%d", &value); Color[i] = value; } memset(Visited, 0, sizeof(Visited)); CountColor(0); memset(Visited, 0, sizeof(Visited)); GetAns(0); printf("%I64d\n", res[0][1]); return 0; }