题目
题目链接:https://codeforces.com/gym/102759/problem/I
给定一棵 \(n\) 个点的有根树,初始时每个点的点权为 \(0\)。
接下来会执行 \(Q\) 次操作,每次操作会是以下两种之一:
- 将 \(u\) 子树内所有顶点的点权增加 \(1\)。
- 将 \(u\) 到 \(v\) 路径上内所有顶点的点权增加 \(1\)。
在每次操作后,设顶点 \(u\) 的点权为 \(a_u\),则输出一个顶点 \(v\),使得
\[\sum^{n}_{u=1}\text{dis}(u,v)\times a_u \]最小。若有多个满足条件的顶点,输出深度最小的一个。
\(n,m\leq 10^5\)。
思路
也就是支持修改求深度最小的带权重心。
有一个并不显而易见的结论:深度最小的带权重心的子树权值和一定严格大于所有点权值和的一半。
证明并不难,假设深度最小带权重心为 \(x\),考虑 \(x\) 的外子树,它的大小一定严格小于所有点权值和的一半,否则 \(x\) 的父亲一定比 \(x\) 更优(至少一半的权值 \(-1\),剩余权值 \(+1\))。
任意取这棵树的一个 dfs 序,按照 dfs 序把每一个点编号写下来形成一个序列,其中节点 \(i\) 连续写 \(a_i\) 次。那么因为深度最小带权重心的子树大小严格大于一半,那么在序列中,这个点的子树所表示区间的长度也一定严格大于一半,所以序列最终间的数一定在答案点的子树内。
那么我们可以用树剖 + 线段树维护,找最中间的点可以用线段树二分找出(因为线段树上编号本来就是一个 dfs 序),那么答案一定是找到的点的一个祖先,倍增寻找即可。
时间复杂度 \(O(Q\log^2 n)\)。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=100010,LG=18;
int n,Q,tot,head[N],siz[N],id[N],rk[N],son[N],top[N],dep[N],f[N][LG+1];
struct edge
{
int next,to;
}e[N*2];
void add(int from,int to)
{
e[++tot]=(edge){head[from],to};
head[from]=tot;
}
void dfs1(int x,int fa)
{
siz[x]=1; f[x][0]=fa; dep[x]=dep[fa]+1;
for (int i=1;i<=LG;i++)
f[x][i]=f[f[x][i-1]][i-1];
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fa)
{
dfs1(v,x); siz[x]+=siz[v];
if (siz[v]>siz[son[x]]) son[x]=v;
}
}
}
void dfs2(int x,int tp)
{
top[x]=tp; id[x]=++tot; rk[tot]=x;
if (son[x]) dfs2(son[x],tp);
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=f[x][0] && v!=son[x]) dfs2(v,v);
}
}
struct SegTree
{
ll sum[N*4],lazy[N*4];
void pushdown(int x,int l,int r)
{
if (lazy[x])
{
int mid=(l+r)>>1;
sum[x*2]+=lazy[x]*(mid-l+1); lazy[x*2]+=lazy[x];
sum[x*2+1]+=lazy[x]*(r-mid); lazy[x*2+1]+=lazy[x];
lazy[x]=0;
}
}
void update(int x,int l,int r,int ql,int qr)
{
if (ql<=l && qr>=r)
{
sum[x]+=r-l+1; lazy[x]++;
return;
}
pushdown(x,l,r);
int mid=(l+r)>>1;
if (ql<=mid) update(x*2,l,mid,ql,qr);
if (qr>mid) update(x*2+1,mid+1,r,ql,qr);
sum[x]=sum[x*2]+sum[x*2+1];
}
ll query1(int x,int l,int r,int ql,int qr)
{
if (ql<=l && qr>=r) return sum[x];
pushdown(x,l,r);
int mid=(l+r)>>1; ll res=0;
if (ql<=mid) res+=query1(x*2,l,mid,ql,qr);
if (qr>mid) res+=query1(x*2+1,mid+1,r,ql,qr);
return res;
}
int query2(int x,int l,int r,ll s)
{
if (l==r) return rk[l];
pushdown(x,l,r);
int mid=(l+r)>>1;
if (sum[x*2]>=s) return query2(x*2,l,mid,s);
else return query2(x*2+1,mid+1,r,s-sum[x*2]);
}
}seg;
int main()
{
memset(head,-1,sizeof(head));
scanf("%d",&n);
for (int i=1,x,y;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
tot=0;
dfs1(1,0); dfs2(1,1);
scanf("%d",&Q);
while (Q--)
{
int opt,x,y;
scanf("%d",&opt);
if (opt==1)
{
scanf("%d",&x);
seg.update(1,1,n,id[x],id[x]+siz[x]-1);
}
else
{
scanf("%d%d",&x,&y);
while (top[x]!=top[y])
{
if (dep[top[x]]<dep[top[y]]) swap(x,y);
seg.update(1,1,n,id[top[x]],id[x]);
x=f[top[x]][0];
}
if (dep[x]<dep[y]) swap(x,y);
seg.update(1,1,n,id[y],id[x]);
}
ll sum=seg.sum[1]/2+1;
x=seg.query2(1,1,n,sum);
if (seg.query1(1,1,n,id[x],id[x]+siz[x]-1)>=sum) printf("%d\n",x);
else
{
for (int i=LG;i>=0;i--)
{
if (!f[x][i]) continue;
int p=f[x][i];
if (seg.query1(1,1,n,id[p],id[p]+siz[p]-1)<sum) x=f[x][i];
}
printf("%d\n",f[x][0]);
}
}
return 0;
}