题意
给定一棵 \(N\) 个节点的树,两个点之间都为 \(1\),一只蜗牛从树的根节点出发寻找自己的壳,壳只可能等概率的出现在叶子节点,某些节点可能有蚯蚓,它可以告诉蜗牛壳是否在该节点的子树内。求蜗牛在最优策略下找到壳所需要的期望步数。多组测试数据。
数据范围
\(1 \leq N \leq 1000\)。
思路
如果用 \(lea[i]\) 表示在 \(i\) 子树内的叶子节点的个数,用 \(f[i]\) 表示当壳在 \(i\) 的子树内时,用最优策略找到壳的所有情况的步数之和,那么最终的答案就是 \(f[root]/lea[i]\) (根据期望的定义)。
考虑如何从 \(f[v]\) 转移到 \(f[u]\)。观察一下题目给出的样例:
在样例中,假定先访问 \(4\) 号节点,再访问 \(5\) 号节点,那么从 \(1\) 出发,直接走到 \(3\) 号节点时,所有方案的步数之和就是 \(2+4\) (\(2\) 的路径是 \(1->3->4\),\(4\) 的路径是 \(1 ->3->4->3-5\))。那么就可以发现,如果 \(v\) 是直接从 \(u\) 走的第一个节点,那么 \(v\) 对 \(f[u]\) 的贡献为 \(f[v]+lea[v]\) (因为壳在 \(v\) 内有 \(lea[v]=2\) 种情况,这两种情况都需要从 \(u\) 走向 \(v\) 一次)。
用 \(back[i]\) 来表示从 \(i\) 出发后,没有在子树内找到壳,再回到 \(i\) 所走的距离。当然,如果 \(i\) 节点有蚯蚓,那么 \(back[i]=0\)。如在样例中,\(back[3]=0\)。而如果壳在 \(2\) 号节点,那么蜗牛走到 \(3\) 号节点的时候就会原路返回,走过的路径就是 \(1 -> 3 -> 1 ->2\),可以发现,此时需要走的步数就是 \((back[3]+2)*lea[v]+lea[2]+f[2]\)。(\(+2\) 是因为之前走的路径是 \(u -> back[v] -> u\),\(* lea[v]\) 是因为有 \(lea[v]\) 种情况)。
综合上面的讨论,可以得出状态转移方程:
\(f[u]=\sum_{i=1}^{son[u]} (\sum_{j=1}^{i-1}(back[j]+2))*lea[i]+f[i]\)。
而现在希望通过改变枚举 \(i\) 的顺序来减小 \(f[u]\) 的值。
假设现在有 \(x\) 和 \(y\) 两个节点,先枚举 \(x\) 再枚举 \(y\) 对答案的贡献:
\((\sum_{j=1}^{x-1}(back[j]+2))*lea[x]+f[x]+(\sum_{j=1}^{x-1}(back[j]+2)+back[x]+2)*lea[y]+f[y]\)。
先枚举 \(y\) 再枚举 \(x\) 对答案的贡献:
\((\sum_{j=1}^{y-1}(back[j]+2))*lea[y]+f[y]+(\sum_{j=1}^{y-1}(back[j]+2)+back[y]+2)*lea[x]+f[x]\)。
上下相减,可以得到:
\((back[x]+2)*lea[y] - (back[y]+2)*lea[x]\)。
于是就可以将叶子节点按照 \((back[x]+2)*lea[y]\) 从小到大排序,这样得到的就是最优策略。
code:
#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=1010;
int n,f[N],lea[N],root,siz[N],back[N];
vector<int> e[N];
bool yy[N];
void init()
{
memset(f,0,sizeof(f));
memset(lea,0,sizeof(lea));
memset(back,0,sizeof(back));
memset(siz,0,sizeof(siz));
memset(yy,0,sizeof(yy));
for(int i=1;i<=n;i++) e[i].clear();
}
bool cmp(int a,int b)
{
return (back[a]+2)*lea[b]<(back[b]+2)*lea[a];
}
void dfs(int u)
{
siz[u]=1;back[u]=lea[u]=0;
if(!e[u].size())
{
lea[u]=1;
return ;
}
for(int i=0;i<e[u].size();i++)
{
int v=e[u][i];
dfs(v);
lea[u]+=lea[v];
siz[u]+=siz[v];
}
// printf("%d\n",u);
sort(e[u].begin(),e[u].end(),cmp);
for(int i=0;i<e[u].size();i++)
{
int v=e[u][i];
// printf(" %d",v);
f[u]+=back[u]*lea[v]+f[v]+lea[v];
// if(u==1) printf("%d\n",f[u]);
back[u]+=back[v]+2;
}
// puts("");
if(yy[u]) back[u]=0;
}
int main()
{
// freopen("nlc.in","r",stdin);
while(scanf("%d",&n),n)
{
init();
for(int u,t,i=1;i<=n;i++)
{
char c;
scanf("%d %c",&u,&c);
if(u==-1) root=i;
else e[u].push_back(i);
if(c==‘Y‘) yy[i]=true;
}
dfs(root);
printf("%.4f\n",f[root]*1.0/lea[root]);
}
return 0;
}