题目
题目链接:https://www.luogu.com.cn/problem/P6329
在一片土地上有 \(n\) 个城市,通过 \(n-1\) 条无向边互相连接,形成一棵树的结构,相邻两个城市的距离为 \(1\),其中第 \(i\) 个城市的价值为 \(value_i\)。
不幸的是,这片土地常常发生地震,并且随着时代的发展,城市的价值也往往会发生变动。
接下来你需要在线处理 \(m\) 次操作:
0 x k
表示发生了一次地震,震中城市为 \(x\),影响范围为 \(k\),所有与 \(x\) 距离不超过 \(k\) 的城市都将受到影响,该次地震造成的经济损失为所有受影响城市的价值和。
1 x y
表示第 \(x\) 个城市的价值变成了 \(y\) 。
为了体现程序的在线性,操作中的 \(x\)、\(y\)、\(k\) 都需要异或你程序上一次的输出来解密,如果之前没有输出,则默认上一次的输出为 \(0\) 。
思路
点分树就是在点分治的基础上,将每次跳的重心与上一次跳的重心连边,构成一棵点分树。也就是一个点 \(x\) 的子节点是点分治时以 \(x\) 为重心的子树扔掉点 \(x\) 后,其余所有的树的重心。
由于点分治只会递归 \(\log n\) 层,所以点分树的深度也是 \(O(\log n)\) 的。
对于本题,构建出点分树,对于每一个点 \(x\),我们维护两棵动态开点线段树,第一棵的一个区间 \([l,r]\) 表示在点分树以 \(x\) 为根的子树中,原树上与 \(x\) 距离在 \([l,r]\) 的点的权值和;第二棵线段树区间 \([l,r]\) 表示在点分树以 \(x\) 为根的子树中,原树上与 \(x\) 在点分树上的父亲之间的距离在 \([l,r]\) 的点的权值和。
对于修改操作,我们从点 \(x\) 不断往点分树上父亲跳,然后维护两棵线段树的值即可。
对于询问操作,我们依然从点 \(x\) 开始网上跳,对于跳到的一个节点 \(a\),设上一个调到的节点 \(b\),那么 \(a\) 会造成的贡献为距离 \(a\) 不超过 \(k-dis_{a,x}\) 的点。但是在 \(b\) 中已经有一部分点背计算过了,这样就会导致重复计算,所以还要减去 \(b\) 的第二棵线段树中不超过 \(k-dis_{a,x}\) 的点。
这样就可以在 \(O(n\log^2n)\) 的复杂度内计算出答案了。
由于这种做法常数较大,我们可以用 ST 表预处理 LCA,每次询问可以 \(O(1)\) 查,并且动态开点线段树可以改为离散化后的树状数组。注意每一个树状数组的大小应当分别离散化,且离散化后大小应为其点分树内子树大小。这样空间复杂度是 \(O(n\log n)\) 的。
代码
#include <bits/stdc++.h>
using namespace std;
const int N=200010,LG=18,Inf=1e9;
int head[N],size[N],dfn[N],maxp[N],fa[N],dep[N],val[N],lg[N],st[N][LG+1];
int n,m,tot,rt,last;
bool vis[N];
vector<int> dis[2][N];
struct edge
{
int next,to;
}e[N*2];
void add(int from,int to)
{
e[++tot]=(edge){head[from],to};
head[from]=tot;
}
void dfs1(int x,int f)
{
st[++tot][0]=x; dfn[x]=tot; dep[x]=dep[f]+1;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=f)
{
dfs1(v,x);
st[++tot][0]=x;
}
}
}
void getst()
{
for (int i=tot;i>=1;i--)
for (int j=1;i+(1<<j)-1<=tot;j++)
if (dep[st[i][j-1]]<dep[st[i+(1<<j-1)][j-1]])
st[i][j]=st[i][j-1];
else
st[i][j]=st[i+(1<<j-1)][j-1];
}
int lca(int x,int y)
{
if (dfn[x]>dfn[y]) swap(x,y);
int k=lg[dfn[y]-dfn[x]+1];
if (dep[st[dfn[x]][k]]<dep[st[dfn[y]-(1<<k)+1][k]])
return st[dfn[x]][k];
else
return st[dfn[y]-(1<<k)+1][k];
}
void findrt(int x,int f,int sum)
{
size[x]=1; maxp[x]=0;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (!vis[v] && v!=f)
{
findrt(v,x,sum);
size[x]+=size[v];
if (size[v]>maxp[x]) maxp[x]=size[v];
}
}
if (sum-size[x]>maxp[x]) maxp[x]=sum-size[x];
if (maxp[x]<maxp[rt] || !rt) rt=x;
}
void dfs2(int x,int f,int sum)
{
fa[x]=f; vis[x]=1;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (!vis[v])
{
rt=0;
int s=(size[v]<size[x]) ? size[v] : sum-size[x];
findrt(v,x,s);
dfs2(rt,x,s);
}
}
}
int getdis(int x,int y)
{
return dep[x]+dep[y]-dep[lca(x,y)]*2;
}
struct BIT
{
vector<int> c;
void add(int x,int v)
{
for (int i=x;i<c.size();i+=i&-i)
c[i]+=v;
}
int query(int x)
{
int sum=0;
for (int i=x;i;i-=i&-i)
sum+=c[i];
return sum;
}
}bit[2][N];
void update(int x,int v)
{
for (int i=x;i;i=fa[i])
{
int p1=upper_bound(dis[0][i].begin(),dis[0][i].end(),getdis(x,i))-dis[0][i].begin();
bit[0][i].add(min(p1,(int)dis[0][i].size()),v-val[x]);
if (fa[i])
{
int p2=upper_bound(dis[1][i].begin(),dis[1][i].end(),getdis(fa[i],x))-dis[1][i].begin();
bit[1][i].add(min(p2,(int)dis[1][i].size()),v-val[x]);
}
}
val[x]=v;
}
int query(int x,int k)
{
int ans=0;
for (int i=x,j=0;i;j=i,i=fa[i])
{
int d=getdis(x,i);
if (d>k) continue;
int p1=upper_bound(dis[0][i].begin(),dis[0][i].end(),k-d)-dis[0][i].begin();
ans+=bit[0][i].query(min(p1,(int)dis[0][i].size()));
if (j)
{
int p2=upper_bound(dis[1][j].begin(),dis[1][j].end(),k-d)-dis[1][j].begin();
ans-=bit[1][j].query(min(p2,(int)dis[1][j].size()));
}
}
return ans;
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++)
scanf("%d",&val[i]);
for (int i=1,x,y;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
tot=0; dfs1(1,0);
getst();
for (int i=2;i<=tot;i++)
lg[i]=lg[i>>1]+1;
findrt(1,0,n);
dfs2(rt,0,n);
for (int i=1;i<=n;i++)
for (int j=i;j;j=fa[j])
{
dis[0][j].push_back(getdis(i,j));
if (fa[j]) dis[1][j].push_back(getdis(i,fa[j]));
}
for (int i=1;i<=n;i++)
{
sort(dis[0][i].begin(),dis[0][i].end());
sort(dis[1][i].begin(),dis[1][i].end());
unique(dis[0][i].begin(),dis[0][i].end());
unique(dis[1][i].begin(),dis[1][i].end());
for (int j=0;j<=dis[0][i].size()+1;j++)
bit[0][i].c.push_back(0);
for (int j=0;j<=dis[1][i].size()+1;j++)
bit[1][i].c.push_back(0);
}
for (int i=1;i<=n;i++)
{
int temp=val[i]; val[i]=0;
update(i,temp);
}
while (m--)
{
int opt,x,y;
scanf("%d%d%d",&opt,&x,&y);
x^=last; y^=last;
if (!opt) printf("%d\n",last=query(x,y));
else update(x,y);
}
return 0;
}