线段树合并学习笔记
线段树合并主要(也就是目前我知道的)适用于树上的一些较为复杂详细问题的维护,一般来说要求修改简单且询问少(一次即可)的情况。它可以在比较优秀的复杂度内(\(O(NlogN)\))完成统计。
主要是代码部分。一般来说它的实现都是对于树上每一个节点都开辟一个动态开点的权值线段树,询问时从下向上合并即可。主要有两种方式,一种方法省空间但不可持久(用得比较多),另一种可持续但耗空间。主要说一下前者的代码。
单点修改很简单。动态开点的模板。
void insert(int wh,int l,int r,int pl,int val){
if(l==r){t[wh].data=pl,t[wh].maxn+=val;return;}
if(pl<=mid)insert(lc==0?lc=++cnt:lc,l,mid,pl,val);
else insert(rc==0?rc=++cnt:rc,mid+1,r,pl,val);
pushup(wh);return;
}
最重要的是合并部分,把后面的线段树合并到前面的那棵上。
int merge(int x,int y,int l,int r){
if(!x)return y;if(!y)return x;
if(l==r){t[x].maxn+=t[y].maxn;return x;}
t[x].left=merge(t[x].left,t[y].left,l,mid);
t[x].right=merge(t[x].right,t[y].right,mid+1,r);
pushup(x);return x;
}
其它的没什么好说的。树上差分模板。
代码:
#include<cstdio>
#define zczc
const int N=100010;
const int S=30;
inline void read(int &wh){
wh=0;int f=1;char w=getchar();
while(w<'0'||w>'9'){if(w=='-')f=-1;w=getchar();}
while(w<='9'&&w>='0'){wh=wh*10+w-'0';w=getchar();}
wh*=f;return;}
inline int max(int s1,int s2){return s1<s2?s2:s1;}
inline int min(int s1,int s2){return s1<s2?s1:s2;}
inline void swap(int &s1,int &s2){int s3=s1;s1=s2;s2=s3;return;}
int m,n,head[N],esum,nt[N][S],d[N],lg[N],root[N],cnt;
struct edge{int t,next;}e[N<<1];
inline void add(int fr,int to){
esum++;e[esum].t=to;e[esum].next=head[fr];head[fr]=esum;return;}
void dfs(int wh,int deep,int fa){
d[wh]=deep,nt[wh][0]=fa;
for(int i=1;i<=lg[d[wh]];i++)nt[wh][i]=nt[nt[wh][i-1]][i-1];
for(int i=head[wh],th;i;i=e[i].next){th=e[i].t;if(th==fa)continue;dfs(th,deep+1,wh);}}
int lca(int s1,int s2){
if(d[s1]<d[s2])swap(s1,s2);
for(int i=lg[d[s1]];i>=0;i--)if(d[nt[s1][i]]>=d[s2])s1=nt[s1][i];if(s1==s2)return s1;
for(int i=lg[d[s1]];i>=0;i--)if(nt[s1][i]!=nt[s2][i])s1=nt[s1][i],s2=nt[s2][i];
return nt[s1][0];}
#define lc t[wh].left
#define rc t[wh].right
#define mid (l+r>>1)
struct node{int maxn,data,left,right;}t[N<<5];
inline void pushup(int wh){
if(t[lc].maxn>=t[rc].maxn)t[wh].maxn=t[lc].maxn,t[wh].data=t[lc].data;
else if(t[rc].maxn)t[wh].maxn=t[rc].maxn,t[wh].data=t[rc].data;
else t[wh].data=t[wh].maxn=0;}
void insert(int wh,int l,int r,int pl,int val){
if(l==r){t[wh].data=pl,t[wh].maxn+=val;return;}
if(pl<=mid)insert(lc==0?lc=++cnt:lc,l,mid,pl,val);
else insert(rc==0?rc=++cnt:rc,mid+1,r,pl,val);
pushup(wh);return;}
int merge(int x,int y,int l,int r){
if(!x)return y;if(!y)return x;
if(l==r){t[x].maxn+=t[y].maxn;return x;}
t[x].left=merge(t[x].left,t[y].left,l,mid);
t[x].right=merge(t[x].right,t[y].right,mid+1,r);
pushup(x);return x;}
int an[N];
void solve(int wh,int fa){
for(int i=head[wh],th;i;i=e[i].next){
th=e[i].t;if(th==fa)continue;solve(th,wh);merge(root[wh],root[th],1,N);}
an[wh]=t[root[wh]].maxn>0?t[root[wh]].data:0;return;}
#undef lc
#undef rc
#undef mid
signed main(){
#ifdef zczc
freopen("in.txt","r",stdin);
#endif
for(int i=0;i<N;i++)lg[i]=lg[i>>1]+1;int s1,s2,s3;read(m);read(n);
for(int i=1;i<=m;i++)root[i]=++cnt;
for(int i=1;i<m;i++){read(s1);read(s2);add(s1,s2);add(s2,s1);}dfs(1,1,0);
while(n--){
read(s1);read(s2);read(s3);int l=lca(s1,s2);
insert(root[s1],1,N,s3,1);insert(root[s2],1,N,s3,1);insert(root[l],1,N,s3,-1);
if(nt[l][0]!=0)insert(root[nt[l][0]],1,N,s3,-1);}
solve(1,0);for(int i=1;i<=m;i++)printf("%d\n",an[i]);return 0;}