【洛谷P6329】【模板】点分树 | 震波

题目

题目链接:https://www.luogu.com.cn/problem/P6329
在一片土地上有 \(n\) 个城市,通过 \(n-1\) 条无向边互相连接,形成一棵树的结构,相邻两个城市的距离为 \(1\),其中第 \(i\) 个城市的价值为 \(value_i\)。

不幸的是,这片土地常常发生地震,并且随着时代的发展,城市的价值也往往会发生变动。

接下来你需要在线处理 \(m\) 次操作:

0 x k 表示发生了一次地震,震中城市为 \(x\),影响范围为 \(k\),所有与 \(x\) 距离不超过 \(k\) 的城市都将受到影响,该次地震造成的经济损失为所有受影响城市的价值和。

1 x y 表示第 \(x\) 个城市的价值变成了 \(y\) 。

为了体现程序的在线性,操作中的 \(x\)、\(y\)、\(k\) 都需要异或你程序上一次的输出来解密,如果之前没有输出,则默认上一次的输出为 \(0\) 。

思路

点分树就是在点分治的基础上,将每次跳的重心与上一次跳的重心连边,构成一棵点分树。也就是一个点 \(x\) 的子节点是点分治时以 \(x\) 为重心的子树扔掉点 \(x\) 后,其余所有的树的重心。
由于点分治只会递归 \(\log n\) 层,所以点分树的深度也是 \(O(\log n)\) 的。
对于本题,构建出点分树,对于每一个点 \(x\),我们维护两棵动态开点线段树,第一棵的一个区间 \([l,r]\) 表示在点分树以 \(x\) 为根的子树中,原树上与 \(x\) 距离在 \([l,r]\) 的点的权值和;第二棵线段树区间 \([l,r]\) 表示在点分树以 \(x\) 为根的子树中,原树上与 \(x\) 在点分树上的父亲之间的距离在 \([l,r]\) 的点的权值和。
对于修改操作,我们从点 \(x\) 不断往点分树上父亲跳,然后维护两棵线段树的值即可。
对于询问操作,我们依然从点 \(x\) 开始网上跳,对于跳到的一个节点 \(a\),设上一个调到的节点 \(b\),那么 \(a\) 会造成的贡献为距离 \(a\) 不超过 \(k-dis_{a,x}\) 的点。但是在 \(b\) 中已经有一部分点背计算过了,这样就会导致重复计算,所以还要减去 \(b\) 的第二棵线段树中不超过 \(k-dis_{a,x}\) 的点。
这样就可以在 \(O(n\log^2n)\) 的复杂度内计算出答案了。
由于这种做法常数较大,我们可以用 ST 表预处理 LCA,每次询问可以 \(O(1)\) 查,并且动态开点线段树可以改为离散化后的树状数组。注意每一个树状数组的大小应当分别离散化,且离散化后大小应为其点分树内子树大小。这样空间复杂度是 \(O(n\log n)\) 的。

代码

#include <bits/stdc++.h>
using namespace std;

const int N=200010,LG=18,Inf=1e9;
int head[N],size[N],dfn[N],maxp[N],fa[N],dep[N],val[N],lg[N],st[N][LG+1];
int n,m,tot,rt,last;
bool vis[N];
vector<int> dis[2][N];

struct edge
{
	int next,to;
}e[N*2];

void add(int from,int to)
{
	e[++tot]=(edge){head[from],to};
	head[from]=tot;
}

void dfs1(int x,int f)
{
	st[++tot][0]=x; dfn[x]=tot; dep[x]=dep[f]+1;
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=f)
		{
			dfs1(v,x);
			st[++tot][0]=x;
		}
	}
}

void getst()
{
	for (int i=tot;i>=1;i--)
		for (int j=1;i+(1<<j)-1<=tot;j++)
			if (dep[st[i][j-1]]<dep[st[i+(1<<j-1)][j-1]])
				st[i][j]=st[i][j-1];
			else
				st[i][j]=st[i+(1<<j-1)][j-1];
}

int lca(int x,int y)
{
	if (dfn[x]>dfn[y]) swap(x,y);
	int k=lg[dfn[y]-dfn[x]+1];
	if (dep[st[dfn[x]][k]]<dep[st[dfn[y]-(1<<k)+1][k]])
		return st[dfn[x]][k];
	else
		return st[dfn[y]-(1<<k)+1][k];
}

void findrt(int x,int f,int sum)
{
	size[x]=1; maxp[x]=0;
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (!vis[v] && v!=f)
		{
			findrt(v,x,sum);
			size[x]+=size[v];
			if (size[v]>maxp[x]) maxp[x]=size[v];
		}
	}
	if (sum-size[x]>maxp[x]) maxp[x]=sum-size[x];
	if (maxp[x]<maxp[rt] || !rt) rt=x;
}

void dfs2(int x,int f,int sum)
{
	fa[x]=f; vis[x]=1;
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (!vis[v])
		{
			rt=0;
			int s=(size[v]<size[x]) ? size[v] : sum-size[x];
			findrt(v,x,s);
			dfs2(rt,x,s);
		}
	}
}

int getdis(int x,int y)
{
	return dep[x]+dep[y]-dep[lca(x,y)]*2;
}

struct BIT
{
	vector<int> c;
	
	void add(int x,int v)
	{
		for (int i=x;i<c.size();i+=i&-i)
			c[i]+=v;
	}
	
	int query(int x)
	{
		int sum=0;
		for (int i=x;i;i-=i&-i)
			sum+=c[i];
		return sum;
	}
}bit[2][N];

void update(int x,int v)
{
	for (int i=x;i;i=fa[i])
	{
		int p1=upper_bound(dis[0][i].begin(),dis[0][i].end(),getdis(x,i))-dis[0][i].begin();
		bit[0][i].add(min(p1,(int)dis[0][i].size()),v-val[x]);
		if (fa[i])
		{
			int p2=upper_bound(dis[1][i].begin(),dis[1][i].end(),getdis(fa[i],x))-dis[1][i].begin();
			bit[1][i].add(min(p2,(int)dis[1][i].size()),v-val[x]);
		}
	}
	val[x]=v;
}

int query(int x,int k)
{
	int ans=0;
	for (int i=x,j=0;i;j=i,i=fa[i])
	{
		int d=getdis(x,i);
		if (d>k) continue;
		int p1=upper_bound(dis[0][i].begin(),dis[0][i].end(),k-d)-dis[0][i].begin();
		ans+=bit[0][i].query(min(p1,(int)dis[0][i].size()));
		if (j)
		{
			int p2=upper_bound(dis[1][j].begin(),dis[1][j].end(),k-d)-dis[1][j].begin();
			ans-=bit[1][j].query(min(p2,(int)dis[1][j].size()));
		}
	}
	return ans;
}

int main()
{
	memset(head,-1,sizeof(head));
	scanf("%d%d",&n,&m);
	for (int i=1;i<=n;i++)
		scanf("%d",&val[i]);
	for (int i=1,x,y;i<n;i++) 
	{
		scanf("%d%d",&x,&y);
		add(x,y); add(y,x);
	}
	tot=0; dfs1(1,0);
	getst();
	for (int i=2;i<=tot;i++)
		lg[i]=lg[i>>1]+1;
	findrt(1,0,n);
	dfs2(rt,0,n);
	for (int i=1;i<=n;i++)
		for (int j=i;j;j=fa[j])
		{
			dis[0][j].push_back(getdis(i,j));
			if (fa[j]) dis[1][j].push_back(getdis(i,fa[j]));
		}
	for (int i=1;i<=n;i++)
	{
		sort(dis[0][i].begin(),dis[0][i].end());
		sort(dis[1][i].begin(),dis[1][i].end());
		unique(dis[0][i].begin(),dis[0][i].end());
		unique(dis[1][i].begin(),dis[1][i].end());
		for (int j=0;j<=dis[0][i].size()+1;j++)
			bit[0][i].c.push_back(0);
		for (int j=0;j<=dis[1][i].size()+1;j++)
			bit[1][i].c.push_back(0);
	}
	for (int i=1;i<=n;i++)
	{
		int temp=val[i]; val[i]=0;
		update(i,temp);
	}
	while (m--)
	{
		int opt,x,y;
		scanf("%d%d%d",&opt,&x,&y);
		x^=last; y^=last;
		if (!opt) printf("%d\n",last=query(x,y));
			else update(x,y);
	}
	return 0;
}
上一篇:最长连续递增子序列 (25 分) 给定一个顺序存储的线性表,请设计一个算法查找该线性表中最长的连续递增子序列。例如,(1,9,2,5,7,3,4,6,8,0)中最长的递增子序列为(3,4,6,8)


下一篇:BZOJ-4242 水壶