21牛客9G - Glass Balls (树上概率dp)

题目

source

题解

对于从\(u\)点出发掉到\(v\)点的球来说,它的贡献是\(dep[u]-dep[v]\)。设对于一个固定的局面,掉到\(v\)点的球的球的个数为\(cnt[v]\),那么所有球的贡献为(即该局面的分数)为:

\[\sum\limits_{i=1}^{n}{dep[i]-\sum\limits_{i=1}^{n}{cnt[i]\cdot dep[i]}} \]

因此,只要分别求出深度总和的期望每个结点掉下去球数的期望即可,可以用树上dp计算。这里有几点要注意的:

  • 局面有合法的情况和非法的情况,因此在转移状态时注意确保的是从合法的子状态以合法的过程转移过来。
  • 树上dp一般计算的是子树的结果,在合并统计答案时要考虑上子树外部分的影响,这也是为什么往往需要两个dfs计算down和up的原因。

从题目中可以容易推得,每个结点的子节点中至多只有一个结点不是“储存点”,否则就是非法的。

设\(dp[i]\)为点\(i\)的子树中到\(i\)的球数的期望;\(down[i]\)为点\(i\)的子树为合法局面的概率;\(up[i]\)为整棵树在点\(i\)​为“储存点”时且除去了\(down[i]\)​的合法概率。这里的\(up[i]\)是为了将\(i\)子树中到点\(i\)的球数的期望转换为整棵树中从\(i\)掉下去的球数的期望,即\(cnt[i]=up[i] \times dp[i]\)。

显然,深度总和的期望就是整棵树合法的概率乘上深度的总和,即\(down[1] \times \sum\limits_{i=1}^n{dep[i]}\)。

\(down\)和\(up\)的转移都比较简单,主要是\(dp\)的转移。设\(P\)为“储存点的概率”,\(t\)为点p子结点的个数。

  • 子结点都是“储存点”,且子节点都合法,此时\(u\)中只有1个球,这种情况的贡献为:

\[dp[u]=1 \times P^t \times \prod_{v {\rm 是}u{\rm的子节点}} {down[v]} \]

  • 子结点\(v\)​不是”储存点“,且子节点都合法,此时\(u\)中除了本身的1个球,还有来自\(dp[v]\)那么多的球,这种情况的贡献为:

\[dp[u]=dp[v]\times P^{t-1}\times (1-P)\times \prod_{v'\neq v}{down[v']}+1\times P^{t-1}\times (1-P)\times\prod_{v' {\rm 是}u{\rm的子节点}} {down[v']} \]

最终答案为:\(down[1] \times \sum\limits_{i=1}^n{dep[i]}-\sum\limits_{i=1}^{n}{up[i] \times dp[i]\cdot dep[i]}\)

#include <bits/stdc++.h>

#define endl '\n'
#define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define mp make_pair
#define seteps(N) fixed << setprecision(N) 
typedef long long ll;

using namespace std;
/*-----------------------------------------------------------------*/

ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
#define INF 0x3f3f3f3f

const int N = 5e5 + 10;
const int M = 998244353;
const double eps = 1e-5;

ll down[N];
ll up[N];
ll dp[N];
int dep[N];
ll po;
vector<int> np[N];

inline ll qpow(ll a, ll b, ll m) {
    ll res = 1;
    while(b) {
        if(b & 1) res = (res * a) % m;
        a = (a * a) % m;
        b = b >> 1;
    }
    return res;
}

void dfs(int p, int fa, int d) {
    dep[p] = d;
    for(int nt : np[p]) {
        if(nt == fa) continue;
        dfs(nt, p, d + 1);
    }
}

void caldown(int p, int fa) {
    ll lp = 1;
    int num = 0;
    for(int nt : np[p]) {
        if(nt == fa) continue;
        num++;
        caldown(nt, p);
        lp = lp * down[nt] % M;
    }
    if(num)
        lp = lp * (qpow(po, num - 1, M) * (1 - po + M) % M * num % M + qpow(po, num, M)) % M;
    down[p] = lp;
}

void calup(int p, int fa) {
    int num = 0;
    for(int nt : np[p]) {
        if(nt == fa) continue;
        num++;
        up[nt] = down[1] * qpow(down[nt], M - 2, M) % M;
    }
    if(num) {
        ll tp = (qpow(po, num - 1, M) * (1 - po + M) % M * num % M + qpow(po, num, M)) % M;
        for(int nt : np[p]) {
            if(nt == fa) continue;
            up[nt] = up[nt] * qpow(tp, M - 2, M) % M;
            up[nt] = up[nt] * (qpow(po, num - 1, M) * (1 - po + M) % M * (num - 1) % M + qpow(po, num, M)) % M;
            calup(nt, p);
        }
    }
}

void solve(int p, int fa) {
    int num = 0;
    ll lp = 1;
    for(int nt : np[p]) {
        if(nt == fa) continue;
        num++;
        lp = lp * down[nt] % M;
        solve(nt, p);
    }
    dp[p] = qpow(po, num, M) * lp % M;
    if(num)
        for(int nt : np[p]) {
            if(nt == fa) continue;
            // 注意后面1的贡献
            // 不要写成(dp[nt] + 1) * qpow(po, num - 1, M) % M * (1 - po + M) % M * lp % M * qpow(down[nt], M - 2, M) % M
            dp[p] += dp[nt] * qpow(po, num - 1, M) % M * (1 - po + M) % M * lp % M * qpow(down[nt], M - 2, M) % M + qpow(po, num - 1, M) % M * (1 - po + M) % M * lp % M;
            // 也可以写成
            // dp[p] += (dp[nt] + down[nt]) * qpow(po, num - 1, M) % M * (1 - po + M) % M * lp % M * qpow(down[nt], M - 2, M) % M;
            
            dp[p] %= M;
        }
}

int main() {
    IOS;
    up[1] = 1;
    int n;
    cin >> n >> po;
    for(int i = 2; i <= n; i++) {
        int f;
        cin >> f;
        np[i].push_back(f);
        np[f].push_back(i);
    }
    dfs(1, 0, 1);
    caldown(1, 0);
    calup(1, 0);
    solve(1, 0);
    ll ans = 0;
    ll tp = down[1];
    for(int i = 1; i <= n; i++) {
        ans = (ans + (tp - up[i] * (dp[i]) % M + M) * dep[i] % M) % M;
    }
    cout << ans << endl;
}
上一篇:idea debug---启动超级慢,提示”Method breakpoints may dramatically slow down debugging“的解决办法


下一篇:折纸问题