[ARC086E]Smuggling Marbles(树形dp+启发式合并)
题面
Sunke有一棵N + 1个点的树,其中0为根,每个点上有0或1个石子,Sunke会不停的进行如下操作直至整棵树没有石子:
把0上面的石子从树上拿走放入口袋;
把每个点上的石子移到其父亲上;
对于每个点,若其石子数≥ 2,则移除该点所有石子(不放入口袋)。
求所有\(2^{N+1}\)种放置石子的方案的最终Sunke口袋中石子数的和为多少.模\(10^9+7\)
分析
方案之和不好算,考虑求概率再乘上\(2^{n+1}\)。
设\(dp[x][d][0]\)表示初始时距离\(x\)深度为\(d\)的石子移动到第\(x\)个点后,第\(x\)个点没有石子的概率。\(dp[x][d][1]\)为有1个石子的概率。\(dp[x][d][2]\)为有2个及以上的石子的概率。
设\(y \in \text{son}(x)\)那么将合并儿子节点的时候,对于\(d \geq 1\)容易写出转移:
\(dp[x][d][0]=dp[x][d][0] \times dp[y][d-1][0]\)
表示\(y\)节点没有石子移动上来
\(dp[x][d][1]=dp[x][d][1] \times dp[y][d-1][0]+dp[x][d][0]*dp[y][d-1][1]\)
表示可能是已经合并上的子树有一个石子放在\(x\)处,也可能是从\(y\)这里移动一个石子
\(dp[x][d][2]=\sum_{0 \leq j,k \leq 2,j+k=2} dp[x][d-1][j] \times dp[y][d-1][k]\),由于式子过长写成这样的形式,意义就是枚举原有的和新合并的石子个数,凑出2个及以上的石子。
合并完之后要处理大于等于2个石子的时候被去掉的情况。也就是把\(dp[x][d][0]\)设为\(dp[x][d][0]+dp[x][d][2]\),再把\(dp[x][d][2]\)置为0.表示大于等于2个石子去掉之后变成了0个石子的状态。
合并完儿子后还有新加进来的,当前节点的状态,\(dp[x][0][0]=dp[x][0][1]=\frac{1}{2},dp[x][0][2]=0\),表示初始的时候第\(x\)个节点可能有也可能没有石子,概率为\(\frac{1}{2}\).注意要用逆元。
最终期望为\(\sum dp[0][d][1]\),因为根据定义往上移动时,根节点有1个石子的时候会对答案产生1的贡献.
考虑如何优化这个合并过程。我们可以把\(dp[x]\)存储在一个vector
内,合并儿子的时候启发式合并,将对应深度的合并在一起。然后加入\(dp[x][0]\).注意为了\(O(1)\)添加元素,vector
里面的元素是倒序存储,最后一个元素对应\(dp[x][1]\).这样加入时直接push_back
就可以了.
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#define maxn 3000000
#define mod 1145141
using namespace std;
typedef long long ll;
inline void qread(int &x){
x=0;
int sign=1;
char c=getchar();
while(c<'0'||c>'9'){
if(c=='-') sign=-1;
c=getchar();
}
while(c>='0'&&c<='9'){
x=x*10+c-'0';
c=getchar();
}
x=x*sign;
}
inline ll fast_pow(ll x,ll k){
ll ans=1;
while(k){
if(k&1) ans=ans*x%mod;
x=x*x%mod;
k>>=1;
}
return ans;
}
const ll inv2=fast_pow(2,mod-2);
int n;
struct edge {
int from;
int to;
int next;
} E[maxn*2+5];
int head[maxn+5];
int sz=1;
void add_edge(int u,int v) {
sz++;
E[sz].from=u;
E[sz].to=v;
E[sz].next=head[u];
head[u]=sz;
}
struct node {
ll f[3];//一个三元组存储dp值
node() {
}
node(ll x,ll y,ll z) {
f[0]=x;
f[1]=y;
f[2]=z;
}
inline ll& operator [] (int i) {
return f[i];
}
};
vector<node>v[maxn+5];
int id[maxn+5];
void merge(int x,int y) {//启发式合并节点
if(v[id[x]].size()<v[id[y]].size()) swap(id[x],id[y]);
x=id[x];
y=id[y];
int nx=v[x].size()-1,ny=v[y].size()-1;
for(int i=0; i<=ny; i++) {
ll sum0,sum1,sum2=0;
int tx=nx-i,ty=ny-i;//注意dp数组是倒序存储的
sum0=v[x][tx][0]*v[y][ty][0]%mod;//合并的dp方程见博客正文
sum1=(v[x][tx][1]*v[y][ty][0]+v[x][tx][0]*v[y][ty][1])%mod;
for(int j=0; j<=2; j++) {
for(int k=2; j+k>=2; k--) {
sum2=(sum2+v[x][tx][j]*v[y][ty][k]%mod)%mod;
}
}
v[x][tx]=node(sum0,sum1,sum2);
}
v[y].clear();
}
int ptr;
void dfs(int x) {
if(!head[x]) id[x]=++ptr;//叶子节点才分配
int maxd=0;
for(int i=head[x]; i; i=E[i].next) {
int y=E[i].to;
dfs(y);
if(!id[x]) id[x]=id[y];//类似长链剖分,直接继承某一个节点
else {
maxd=max(maxd,min((int)v[id[x]].size(),(int)v[id[y]].size()));//记录一下距x的最大深度,方便转移
merge(x,y);
}
}
int nx=v[id[x]].size()-1;
for(int i=0; i<maxd; i++) {
v[id[x]][nx-i][0]=(v[id[x]][nx-i][0]+v[id[x]][nx-i][2])%mod;
v[id[x]][nx-i][2]=0;
}
v[id[x]].push_back(node(inv2,inv2,0));//加入当前节点的状态
}
int main() {
int f;
qread(n);
for(int i=1; i<=n; i++) {
qread(f);
add_edge(f,i);
}
dfs(0);
ll ans=0;
for(int i=0;i<(int)v[id[0]].size();i++) (ans=ans+v[id[0]][i][1]);
ans=ans*fast_pow(2,n+1)%mod;
printf("%lld\n",ans);
}