Code:
#include <cstdio>
#include <algorithm>
#include <vector>

// Debug helper: redirects stdin to "<s>.in". Only referenced from the
// commented-out call in main(); kept for local testing.
#define setIO(s) freopen(s".in","r",stdin)

typedef long long ll;

const int N = 200005;      // capacity of every per-node / per-edge array (1-based indices)
const ll mod = 998244353;  // modulus for all arithmetic

int n, edges;              // node count read from input; number of edges stored so far
ll f[N], g[N];             // per-node DP values, see dfs()
int hd[N], to[N], nex[N];  // forward-star child lists: hd[u] = first edge id of u (0 = none)
ll before[N];              // scratch for dfs(): product of f over the earlier siblings of a node

/// Appends the directed edge u -> v (parent -> child) to u's child list.
/// Edge ids start at 1 so that 0 can serve as the end-of-list marker.
void add(int u, int v)
{
    nex[++edges] = hd[u];
    hd[u] = edges;
    to[edges] = v;
}

/// Fills f[] and g[] for every node in the subtree of `u` (all values mod `mod`).
///
/// For a node without children:  f = g = 1.
/// For a node with children c_1..c_k:
///     all   = prod_i (f[c_i] + g[c_i])
///     g[u]  = all - prod_i f[c_i]
///     f[u]  = all - sum_i ( g[c_i] * prod_{j != i} f[c_j] )
///
/// NOTE(review): this looks like the "open/closed component" DP for counting
/// leaf partitions of a rooted tree (g = at least one child is chosen in its
/// g-state, f = the number of such children is not exactly one) — confirm
/// against the problem statement.
///
/// Implementation notes:
///  * Iterative (explicit stack) so a path-shaped tree with ~2e5 nodes cannot
///    overflow the call stack.
///  * prod_{j != i} f[c_j] is assembled from prefix and suffix products instead
///    of dividing the full product by f[c_i]; a modular inverse would yield a
///    wrong term whenever f[c_i] is congruent to 0 modulo `mod`.
void dfs(int u)
{
    // Pre-order listing of the subtree: every node appears after its parent,
    // so walking the list backwards visits children before parents.
    std::vector<int> order;
    std::vector<int> pending(1, u);
    while (!pending.empty()) {
        const int cur = pending.back();
        pending.pop_back();
        order.push_back(cur);
        for (int e = hd[cur]; e; e = nex[e])
            pending.push_back(to[e]);
    }

    std::vector<int> kids;  // reused buffer holding the children of the current node
    for (std::size_t idx = order.size(); idx-- > 0; ) {
        const int cur = order[idx];

        if (!hd[cur]) {  // no children
            f[cur] = g[cur] = 1;
            continue;
        }

        // Forward pass: full products plus, for each child, the product of f
        // over the siblings that precede it in the list.
        kids.clear();
        ll all = 1;   // prod (f[v] + g[v])
        ll allF = 1;  // prod f[v]
        for (int e = hd[cur]; e; e = nex[e]) {
            const int v = to[e];
            kids.push_back(v);
            before[v] = allF;
            allF = allF * f[v] % mod;
            all = all * ((f[v] + g[v]) % mod) % mod;
        }

        // Backward pass: combine with the product over the later siblings.
        ll exactlyOne = 0;  // sum_v g[v] * prod_{w != v} f[w]
        ll after = 1;       // product of f over children already seen in this pass
        for (std::size_t k = kids.size(); k-- > 0; ) {
            const int v = kids[k];
            exactlyOne = (exactlyOne + before[v] * after % mod * g[v]) % mod;
            after = after * f[v] % mod;
        }

        g[cur] = (all + mod - allF) % mod;
        f[cur] = (all + mod - exactlyOne) % mod;
    }
}

/// Reads n followed by the parents of nodes 2..n from stdin, runs the DP
/// from node 1 and prints f[1]. Exits with status 1 on malformed input
/// (missing numbers, n outside [1, N-1], or a parent outside [1, n]).
int main()
{
    // setIO("input");
    if (std::scanf("%d", &n) != 1 || n < 1 || n >= N)
        return 1;
    for (int i = 2; i <= n; ++i) {
        int parent;
        if (std::scanf("%d", &parent) != 1 || parent < 1 || parent > n)
            return 1;
        add(parent, i);
    }
    dfs(1);
    std::printf("%lld\n", f[1]);
    return 0;
}