树链剖分
简单来说就是数据结构在树上的应用。常用的为线段树splay等。(可现在splay还不会敲囧)
重链剖分:
将树上的边分成轻链和重链。
重边为每个节点到它子树最大的儿子的边,其余为轻边。
设(u,v)为轻边,则size(v)<=size(u)/2 (一旦大于了那必然是重边)
也就是一条路径上每增加一条轻边节点个数就会减少一半以上,那么显然根到任意一个节点路径上的轻边条数一定不会超过log(n)(不然节点就没了啊23333)
重链定义为一条极长的连续的且全由重边构成的链。
容易看出重链两两互不相交
而且在一条路径上重链是由一条条轻边隔开的,所以重链的条数也<=log(n)
我们先进行一次dfs可以将每个节点子树的大小记录好
再进行一次dfs可以刷出重边以及构造重链
dfs2过程的意图和实现
显然每个点在且仅在一条重链里,那么对于点信息的维护我们就可以用将这些链首尾相接放到一个大的线段树里面做
按照dfs序将重链插入线段树,具体的实现是给每个点重新赋一个标号,并且记录每个点所在的链的头节点
先遍历一遍子节点找出一个子树最大的儿子,继续拓展下去
对于剩下的儿子以儿子节点为起点重新开始一条重链
对于询问路径上加和\最值等问题,如何将询问的区间转移到线段树上?
首先可以用倍增算法找到两点(x,y)的最近公共祖先t
然后分别对(x,t)和(y,t)两条路径操作
理想状态是x,t在同一条重链中,那样我们就可以直接在线段树上做了
(为啥不在一条重链中就不可以直接做呢。。?)
像上图这种情况,橙色的点是我们要求的区间。
但是显然在处理到第二个点的时候,会优先向右拓展
编号就不是连续的了。只有在一条重链中才能保证编号的连续的。
(这里的编号指的是在dfs2过程中新的编号)
我们为了达到这种理想状态,就要从x节点一点一点向上爬
每次累加x所在重链的头节点到x节点区间内的答案,然后跳过向上的一条轻边,直到最后的x,t在同一条重链中,再统计在这条重链中的答案
由于重链和轻边的条数都不会超过log(n),所以这一步的复杂度也可以粗略估计为O(log(n))
统计答案就是最基础的线段树操作。
另外由于刚开始并不大理解线段树是一个而不是每条重链上一个,所以还去算了一下重链的条数最大值来估计空间。。。
附上非常奇怪的证明
以下所有u节点表示重边起点,v节点表示重边终点,x表示重链条数
以每个非叶子节点为起点都会产生一条重链
每个非叶子的v节点又会在上面基础上减少一条重链的产生
X=非叶子节点数-非叶子节点中的v节点数
非叶子节点中的v节点数=v节点数-叶子节点中的v节点数
V节点数=u节点数=非叶子节点数
非叶子节点中的v节点数=非叶子节点数-叶子节点中的v节点数
X=非叶子节点数-(非叶子节点数-叶子节点中的v节点数)
= 叶子节点中的v节点数
我们只要保证每个叶子节点的父亲都只有一个儿子就可以使每个叶子节点都是v节点
Max(X)=Max(叶子节点中的v节点数)=叶子节点数
BZOJ1036[ZJOI2008]树的统计Count
Description
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 III. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身
多余的分析就不需要了 上面讲树链剖分就是以这道题为例的
感谢黄学长的模板看了一遍就完全理解了 而且写得很漂亮
代码略长,以后可以多写写练练手感
program bzoj1036;
const maxn=;maxm=;
var n,i,j,x,y,q,cnt,t:longint;
ch:char;
ter,next:array[-..maxm]of longint;
deep,pos,size,link,belong,v:array[-..maxn]of longint;
fa:array[-..maxn,-..]of longint;
tr:array[-..*maxn]of record l,r,mx,sum:longint;end; function max(a,b:longint):longint;
begin
if a>b then exit(a) else exit(b);
end; procedure add(x,y:longint);
begin
inc(j);ter[j]:=y;next[j]:=link[x];link[x]:=j;
inc(j);ter[j]:=x;next[j]:=link[y];link[y]:=j;
end; procedure dfs1(p:longint);
var j:longint;
begin
size[p]:=;
for i:= to do
begin
if deep[p]<= << i then break;
fa[p][i]:=fa[fa[p][i-]][i-];
end;
j:=link[p];
while j<> do
begin
if deep[ter[j]]= then
begin
deep[ter[j]]:=deep[p]+;
fa[ter[j]][]:=p;
dfs1(ter[j]);
inc(size[p],size[ter[j]]);
end;
j:=next[j];
end;
end; procedure dfs2(p,chain:longint);
var k,j:longint;
begin
inc(cnt);pos[p]:=cnt;belong[p]:=chain;
k:=;
j:=link[p];
while j<> do
begin
if deep[ter[j]]>deep[p] then
if size[ter[j]]>size[k] then k:=ter[j];
j:=next[j];
end;
if k= then exit;
dfs2(k,chain);
j:=link[p];
while j<> do
begin
if deep[ter[j]]>deep[p] then
if ter[j]<>k then dfs2(ter[j],ter[j]);
j:=next[j];
end;
end; procedure build(p,l,r:longint);
var mid:longint;
begin
tr[p].l:=l;tr[p].r:=r;tr[p].sum:=;tr[p].mx:=-maxlongint;
if l=r then exit;
mid:=(l+r) >> ;
build(p << ,l,mid);
build(p << +,mid+,r);
end; procedure insert(p,loc,x:longint);
var mid:longint;
begin
if (tr[p].l=loc)and(tr[p].r=loc) then
begin
tr[p].sum:=x;tr[p].mx:=x;
exit;
end;
mid:=(tr[p].l+tr[p].r) >> ;
if loc<=mid then insert(p << ,loc,x) else insert(p << +,loc,x);
tr[p].sum:=tr[p << ].sum+tr[p << +].sum;
tr[p].mx:=max(tr[p << ].mx,tr[p << +].mx);
end; function lca(x,y:longint):longint;
var i,tem:longint;
begin
if deep[x]<deep[y] then
begin
tem:=x;x:=y;y:=tem;
end;
if deep[x]<>deep[y] then
begin
i:=trunc(ln(deep[x]-deep[y])/ln());
while deep[x]>deep[y] do
begin
while (deep[x]-deep[y]>= << i) do x:=fa[x][i];
dec(i);
end;
end;
if x=y then exit(x);
i:=trunc(ln(n)/ln());
while fa[x][]<>fa[y][] do
begin
while fa[x][i]<>fa[y][i] do
begin
x:=fa[x][i];y:=fa[y][i];
end;
dec(i);
end;
exit(fa[x][]);
end; function query_sum(p,l,r:longint):longint;
var mid:longint;
begin
if (tr[p].l=l)and(tr[p].r=r) then exit(tr[p].sum);
mid:=(tr[p].l+tr[p].r) >> ;
if r<=mid then exit(query_sum(p << ,l,r)) else
if l>mid then exit(query_sum(p << +,l,r)) else
exit(query_sum(p << ,l,mid)+query_sum(p << +,mid+,r));
end; function query_mx(p,l,r:longint):longint;
var mid:longint;
begin
if (tr[p].l=l)and(tr[p].r=r) then exit(tr[p].mx);
mid:=(tr[p].l+tr[p].r) >> ;
if r<=mid then exit(query_mx(p << ,l,r)) else
if l>mid then exit(query_mx(p << +,l,r)) else
exit(max(query_mx(p << ,l,mid),query_mx(p << +,mid+,r)));
end; function solve_sum(x,y:longint):longint;
var sum:longint;
begin
sum:=;
while belong[x]<>belong[y] do
begin
inc(sum,query_sum(,pos[belong[x]],pos[x]));
x:=fa[belong[x]][];
end;
inc(sum,query_sum(,pos[y],pos[x]));
exit(sum);
end; function solve_mx(x,y:longint):longint;
var mx:longint;
begin
mx:=-maxlongint;
while belong[x]<>belong[y] do
begin
mx:=max(mx,query_mx(,pos[belong[x]],pos[x]));
x:=fa[belong[x]][];
end;
mx:=max(mx,query_mx(,pos[y],pos[x]));
exit(mx);
end; begin
readln(n);
for i:= to n- do
begin
readln(x,y);
add(x,y);
end;
deep[]:=;dfs1();cnt:=;dfs2(,);
build(,,n);
for i:= to n do
begin
read(v[i]);
insert(,pos[i],v[i]);
end;
readln(q);
for i:= to q do
begin
read(ch);
if ch='C' then
begin
readln(ch,ch,ch,ch,ch,x,y);
v[x]:=y;
insert(,pos[x],y);
end else
begin
read(ch);
if ch='M' then
begin
readln(ch,ch,x,y);
t:=lca(x,y);
writeln(max(solve_mx(x,t),solve_mx(y,t)));
end else
begin
readln(ch,ch,x,y);
t:=lca(x,y);
writeln(solve_sum(x,t)+solve_sum(y,t)-v[t]);
end;
end;
end;
end.