没想到最后一步。还是太菜了。
简要题意
-
\(n\) 种卡牌,每一轮抽到第 \(i\) 种卡牌的概率为 \(\dfrac{w_i}{\sum w_j}\),其中 \(\forall j\in\{1,2,3\}\),\(w_i\) 有 \(p_{i,j}\) 的概率为 \(j\)。
-
设第一次抽到 \(i\) 的时间为 \(T_i\),有 \(n-1\) 对限制构成一棵有向树,满足对于所有树边 \(u_i\to v_i\),\(T_{u_i} \lt T_{v_i}\)。求满足所有限制的概率。
-
\(n \le 1000\)
题解
若钦定树根,这棵树有着纷繁错杂的连边关系:根向、叶向可以交替出现,这就为解题带来很大不便。为此,先考虑所有边方向一致(不妨假定为根向,叶向同理)。
需要注意的是,每种卡牌出现的概率并不独立,这就意味着不能用简单的期望代替 \(w_i\)。由于边是根向的,题目中的限制等价于:任意结点先于子树中的结点出现(记为条件 A)。而当子树中的 \(w_i\) 确定时,条件 A 成立的概率是确定的:\(P(A) = \dfrac{w_u}{\sum\limits_{v\in \text{subtree}\ u}w_v}\)。因此只需记录子树的 \(w_v\) 之和即可 dp:\(f_{u,i}\) 表示以 \(u\) 为根的子树中,\(w_v\) 之和为 \(i\) 且在该子树中满足所有限制的概率。则: \(f^{'}_{u,i+j} = \dfrac{1}{i+j}\sum_{i,j}f_{u,i}\cdot f_{v,j}\),初值 \(f_{u,j} = p_{u,j}\cdot j(j \le 3)\)。
当加入反向边之后,情况就不再这么简单——直接做会导致系数乱套。考虑一个容斥:\(P(反向) = P(不存在)-P(正向)\) 。于是在树边上记一个容斥系数即可 dp。
Code
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int MAXN = 1010;
const int MOD = 998244353;
inline int add(int a, int b) {return (a+=b)>=MOD?a-MOD:a;}
inline void inc(int& a, int b) {a = add(a, b);}
inline int sub(int a, int b) {return (a-=b)<0?a+MOD:a;}
inline void dec(int& a, int b) {a = sub(a, b);}
inline int mul(int a, int b) {return 1ll*a*b%MOD;}
inline void mlt(int& a, int b) {a = mul(a, b);}
inline int pw(int x, int p = MOD-2)
{
int res = 1;
for(;p;p>>=1,mlt(x,x)) if(p&1) mlt(res, x);
return res;
}
int inv[MAXN*3];
inline void initInv(int n)
{
inv[1] = 1;
for(int i=2;i<=n;i++) inv[i] = mul(MOD-MOD/i, inv[MOD%i]);
}
int n, p[MAXN][4];
struct edge{
int ne, to, w;
edge(int N=0,int T=0,int W=0):ne(N),to(T),w(W){}
}e[MAXN<<1];
int fir[MAXN], num = 0;
inline void join(int a, int b, int c)
{
e[++num] = edge(fir[a], b, c);
fir[a] = num;
}
int f[MAXN][MAXN*3], siz[MAXN], tmp[MAXN*3];
void dfs(int u, int fa)
{
for(int i=1;i<=3;i++) f[u][i] = mul(p[u][i], i);
siz[u] = 3;
for(int i=fir[u];i;i=e[i].ne)
{
int v = e[i].to;
if(v == fa) continue;
dfs(v, u);
memset(tmp, 0, (siz[u]+siz[v]+2)<<2);
for(int j=1;j<=siz[u];j++)
for(int k=1;k<=siz[v];k++)
{
inc(tmp[j+k], mul(mul(f[u][j], f[v][k]), e[i].w));
if(e[i].w != 1) inc(tmp[j], mul(f[u][j], f[v][k]));
}
siz[u] += siz[v];
for(int j=1;j<=siz[u];j++) f[u][j] = tmp[j];
}
for(int i=1;i<=siz[u];i++) mlt(f[u][i], inv[i]);
}
inline void work()
{
scanf("%d",&n);
initInv(n*3);
for(int i=1;i<=n;i++)
{
int s = 0;
for(int j=1;j<=3;j++) scanf("%d",p[i]+j), inc(s, p[i][j]);
s = pw(s);
for(int j=1;j<=3;j++) mlt(p[i][j], s);
}
for(int i=1,u,v;i<n;i++)
{
scanf("%d%d",&u,&v);
join(u, v, 1);
join(v, u, MOD-1);
}
dfs(1, 0);
int ans = 0;
for(int i=1;i<=n*3;i++) inc(ans, f[1][i]);
printf("%d\n",ans);
}
int main()
{
int T = 1;
// scanf("%d%d",&T,&MOD);
// prework(N);
while(T--) work();
return 0;
}