【CF860E】Arkady and a Nobody-men
题意:给你一棵n个点的有根树。如果b是a的祖先,定义$r(a,b)$为b的子树中深度小于等于a的深度的点的个数(包括a)。定义$z(a)=\sum\limits r(a,b)$(b是a的祖先)。要你求出每个点的z值。
$n\le 5\times 10^5$
题解:一开始naive的思路:将所有点按深度排序,将深度相同的点统一处理,统计答案时相当于链加,链求和,用树剖+树状数组搞一搞,时间复杂度$O(n\log^2n)$。
后来看题解发现我这个想法简直菜爆了。我们先从树形DP的角度去想,先给出转移方程:
$ans(x)=ans(fa(x))+dep(x)+ans'(x)$,ans'(x)表示与a深度相同的点 对a的贡献。
现在问题变成了求ans',我们考虑在每个点对的lca处统计贡献。具体地,我们对于每个点x,维护若干个三元组(d,a,cnt)表示x的子树中有cnt个d级子孙,其中一个子孙为a。DP的过程就相当于在父亲节点处将所有儿子节点的三元组合并,在合并时顺便统计贡献。
具体地,合并方式如下:假如x有两个儿子,它们有三元组$(d,a,cnt_a)$和$(d,b,cnt_b)$,则:
1.$ans'(a)+=dep(x)\times cnt_b$
2.$ans'(b)+=dep(x)\times cnt_a$
3.得到新三元组$(d,a,cnt_a+cnt_b)$
但是后面的点 对b的贡献呢?我们发现后面的点 对a和b的贡献就是相同的了,所以我们建一个新图,在新图中从a到b连一条长度为$ans'(b)-ans'(a)$的边,最后在新图上DFS一下,最最后统计一下ans数组即可。
以上过程采用长链剖分优化,由于一开始的三元组个数为n,则每次合并都会减少一个三元组,所以时间复杂度O(n)。
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
using namespace std;
const int maxn=500010;
typedef long long ll;
ll ans[maxn];
struct node
{
int v,x;
node() {}
node(int a,int b) {v=a,x=b;}
}mem[maxn<<1],*f[maxn],*now=mem;
int n,cnt,Cnt,rt;
int to[maxn],nxt[maxn],head[maxn],dep[maxn],md[maxn],son[maxn],fa[maxn];
bool vis[maxn];
int To[maxn],Nxt[maxn],Head[maxn];
ll Val[maxn];
inline void add(int a,int b)
{
to[cnt]=b,nxt[cnt]=head[a],head[a]=cnt++;
}
inline void Add(int a,int b,ll c)
{
To[Cnt]=b,Val[Cnt]=c,Nxt[Cnt]=Head[a],Head[a]=Cnt++;
}
void dfs1(int x)
{
md[x]=0;
for(int i=head[x];i!=-1;i=nxt[i])
{
dep[to[i]]=dep[x]+1,dfs1(to[i]);
if(md[to[i]]+1>md[x]) md[x]=md[to[i]]+1,son[x]=to[i];
}
}
void dfs2(int x)
{
if(f[x]==NULL) f[x]=now,now+=md[x]+2;
if(son[x]) f[son[x]]=f[x]+1,dfs2(son[x]);
f[x][0]=node(1,x);
for(int i=head[x];i!=-1;i=nxt[i]) if(to[i]!=son[x])
{
dfs2(to[i]);
for(int j=0;j<=md[to[i]];j++)
{
node a=f[x][j+1],b=f[to[i]][j];
ans[b.x]+=1ll*dep[x]*a.v;
ans[a.x]+=1ll*dep[x]*b.v;
Add(a.x,b.x,ans[b.x]-ans[a.x]);
f[x][j+1]=node(a.v+b.v,a.x);
}
}
}
void dfs3(int x)
{
for(int i=Head[x];i!=-1;i=Nxt[i]) ans[To[i]]=ans[x]+Val[i],dfs3(To[i]);
}
void dfs4(int x)
{
for(int i=head[x];i!=-1;i=nxt[i]) ans[to[i]]+=ans[x]+dep[x],dfs4(to[i]);
}
inline int rd()
{
int ret=0,f=1; char gc=getchar();
while(gc<'0'||gc>'9') {if(gc=='-') f=-f; gc=getchar();}
while(gc>='0'&&gc<='9') ret=ret*10+(gc^'0'),gc=getchar();
return ret*f;
}
int main()
{
n=rd();
int i;
memset(head,-1,sizeof(head)),memset(Head,-1,sizeof(Head));
for(i=1;i<=n;i++)
{
fa[i]=rd();
if(!fa[i]) rt=i;
else add(fa[i],i);
}
dep[rt]=1,dfs1(rt);
dfs2(rt);
for(i=0;i<=md[rt];i++) dfs3(f[rt][i].x);
dfs4(rt);
for(i=1;i<=n;i++) printf("%lld ",ans[i]);
return 0;
}