首先有树形 DP 。
设 \(P_{u,val}\) 表示 \(u\) 节点值为 \(val\) 的概率,\(lc\) 表示左儿子, \(rc\) 表示右儿子。
有转移:
\(P_{u,val} \gets P_{lc,val} \times p_u \times \sum_{i=0}^{val-1} P_{rc,i}\)
\(P_{u,val} \gets P_{rc,val} \times p_u \times \sum_{i=0}^{val-1} P_{lc,i}\)
\(P_{u,val} \gets P_{lc,val} \times (1-p_u) \times \sum_{i=val+1}^{+\infty} P_{rc,i}\)
\(P_{u,val} \gets P_{rc,val} \times (1-p_u) \times \sum_{i=val+1}^{+\infty} P_{lc,i}\)
对叶子节点的值可以离散化,对式子进行前缀和优化后,设叶子节点个数为 \(m\) ,复杂度为 \(\mathcal{O}(n\times m)\) ,可以获得 50pts 的好成绩。
发现有前后缀和,考虑用线段树,用离散化后的线段树暴力维护,时间复杂度 \(\mathcal{O}(n\times m\log m)\) ,十分糟糕。
发现转移方程求和很奇特,是一个区间和的形式,这样在类似线段树建立的过程中,我们也可以在到达每个叶子节点时算出将要使用的式子,所以就要依赖这一过程进行转移。
「类似线段树建立的过程」\(\rightarrow\) 「线段树合并」
时间复杂度 \(\mathcal{O}(n\log m)\) 。
Code(C++):
#include<bits/stdc++.h>
#define forn(i,s,t) for(int i=(s);i<=(t);++i)
using namespace std;
const int N = 3e5+3,M = 1.2e7+3,Mod = 998244353,inv = 796898467;
char ch;
template<typename T>inline void redn(T &ret) {
ret=0,ch=getchar();
while(!isdigit(ch)) ch=getchar();
while(isdigit(ch)) ret=(ret<<1)+(ret<<3)+ch-48,ch=getchar();
}
int n,fa[N],p[N],son[N][2],rft[N],Rn,D[N];
struct SegTree {
int E[M],L[M],R[M],mul[M],rt[N],crash[M],cra,sl;
inline void init(int p) {
E[p] = L[p] = R[p] = 0;
mul[p] = 1;
}
inline int Ads() {return cra?crash[cra--]:++sl;}
inline void Del(int p) {crash[++cra]=p,init(p);}
inline void push_up(int p) {
E[p] = 1ll*(E[L[p]] + E[R[p]]) %Mod;
}
inline void mlt(int p,int val) {
E[p] = 1ll*E[p]*val %Mod;
mul[p] = 1ll*mul[p]*val %Mod;
}
inline void push_down(int p) {
if(mul[p] != 1)
mlt(L[p],mul[p]),mlt(R[p],mul[p]),
mul[p] = 1;
}
void Upd(int &p,int l,int r,int pos,int val) {
if(!p) p=Ads(),init(p);
if(l == r) {
E[p] = val;
return ;
}
int mid = l+r >> 1;
push_down(p);
if(pos<=mid) Upd(L[p],l,mid,pos,val);
else Upd(R[p],mid+1,r,pos,val);
push_up(p);
}
void Mrg(int &p,int pre,int sump,int sumpre,int P) {
if(!p||!pre) {
if(!p&&!pre) return ;
if(!p) mlt(pre,sumpre);
else mlt(p,sump);
p += pre;
return ;
}
push_down(p),push_down(pre);
int vall = E[L[p]],valr = E[R[p]];
int valll = E[L[pre]],valrr = E[R[pre]];
Mrg(L[p],L[pre],(1ll*sump + 1ll*(1-P+Mod)*valrr%Mod)%Mod,
(1ll*sumpre + 1ll*(1-P+Mod)*valr%Mod)%Mod,P);
Mrg(R[p],R[pre],(1ll*sump + 1ll*P*valll%Mod)%Mod,
(1ll*sumpre + 1ll*P*vall%Mod)%Mod,P);
Del(pre);
push_up(p);
}
void res(int p,int l,int r) {
if(!p) return ;
if(l == r) {
D[l] = E[p];
return ;
}
push_down(p);
int mid = l+r >> 1;
res(L[p],l,mid),res(R[p],mid+1,r);
}
}T;
void solve(int u) {
if(!son[u][0]) T.Upd(T.rt[u],1,Rn,p[u],1);
else if(son[u][1]) {
solve(son[u][0]),solve(son[u][1]);
T.rt[u] = T.rt[son[u][0]];
T.Mrg(T.rt[u],T.rt[son[u][1]],0,0,p[u]);
} else {
solve(son[u][0]);
T.rt[u] = T.rt[son[u][0]];
}
}
int main() {
redn(n);
forn(i,1,n) redn(fa[i]);
forn(i,1,n) redn(p[i]);
forn(i,1,n) if(son[fa[i]][0]) son[fa[i]][1] = i;
else son[fa[i]][0] = i;
forn(i,1,n) if(son[i][0]) p[i] = 1ll*p[i]*inv%Mod;
else rft[++Rn] = p[i];
sort(rft+1,rft+Rn+1); Rn = unique(rft+1,rft+Rn+1) - rft - 1;
forn(i,1,n) if(!son[i][0]) p[i] = lower_bound(rft+1,rft+Rn+1,p[i]) - rft;
solve(1);
T.res(T.rt[1],1,Rn);
int Ans = 0;
forn(i,1,Rn) Ans = 1ll*(Ans + 1ll*i*D[i]%Mod*D[i]%Mod*rft[i]%Mod)%Mod;
printf("%d\n",Ans);
return 0;
}