[学习笔记]——DSU on tree
模拟赛要结束时,ssw02问sxk:T3 你写的什么?
sxk :我终于把T3调出来了
ssw02 : 666,所以你写的什么?
sxk :我以前还没写过。
ssw02 : ???(自闭)
考试后 sxk : 我不给你说了我写的是 DSU on tree 吗,我以前还没怎么写过 。
ssw02 : ???
算法引入
当你发现一道树上的题,有一些统计性任务,而且不带修改,并且只能进行单点统计之类的,复杂度还只能承受 NlogN ,那么,先别写某些毒瘤树套树,DSU on tree (树上启发式合并) 是一个不错的选择。
毒瘤noip原话 :
说实话"dsu on tree"是个极其有问题的民科叫法吧。。(没有怼人的意思。。)
这东西几十年前就有了啊,那个人自己YY了个链分治就瞎起名字。。
好的,我们切入正题,如何让一个 N^2 级别的暴力子树统计变成一个 NlogN 数据结构的算法
算法流程
我们考虑利用轻重链剖分的性质来对算法复杂度进行优化。(假设你会树剖)
假设我们正在递归处理子树 u 。
我们进行如下操作(乱搞):
1.我们先暴力跑 u 的轻儿子所在的子树,同时我们删除递归的贡献值。
2.然后我们搞 u 的重儿子,这个时候我们要把贡献给算上了。
3.我们发现子树的答案还没有更新,怎么办?再把 u 的轻儿子都递归一遍,同时统计累计贡献 。
4.这时候就可以得出子树的答案了。
5.然后,没有然后了吧,我们好像在递归时就处理了每个节点吧。
实现代码:
void deal( int u , int fa , int val ){
cnt[ col[ u ] ] += val ;
if( cnt[ col[ u ] ] > mx )mx = cnt[ col[ u ] ] , sum = col[ u ] ;
else if( cnt[ col[ u ] ] == mx ) sum += (ll)col[ u ] ;
for( int i = head[ u ] ; i ; i = nex[ i ] ){
if( to[ i ] == fa )continue ;
if( to[ i ] != S )
deal( to[ i ] , u , val ) ;
}
}
void dfs2( int u , int fa , int opt ){
for( int i = head[ u ] ; i ; i = nex[ i ] ){
if( to[ i ] == fa )continue ;
if( to[ i ] != son[ u ] )
dfs2( to[ i ] , u , 0 ) ;//0清除
}
if( son[ u ] )dfs2( son[ u ] , u , 1 ) , S = son[ u ] ;
deal( u , fa , 1 ) ;//递归处理子树,统计轻儿子贡献
ans[ u ] = sum , S = 0 ;
if( !opt )deal( u , fa , -1 ) , sum = 0 , mx = 0 ;//memset上仙
}
复杂度证明
同树链剖分(这里指轻重链剖分,不是长链剖分)。
注意,清除贡献的时候,如果你用了 memset ,emmmm,恭喜您上天了 。
清除代码:
ans[ u ] = sum , S = 0 ;
if( !opt )deal( u , fa , -1 ) , sum = 0 , mx = 0 ;//memset上仙
不信你看:下面是用 memset 的
树上数颜色
给一棵根为1的树,每次询问子树颜色种类数
AC代码:
#include<bits/stdc++.h>
using namespace std ;
#define ll long long
const int MAXN = 100005 ;
inline int read(){
int s=0 ; char g=getchar() ; while(g>'9'||g<'0')g=getchar() ;
while( g>='0'&&g<='9' )s=s*10+g-'0',g=getchar() ; return s ;
}
int N , M , col[ MAXN ] , son[ MAXN ] , size[ MAXN ] , cnt[ MAXN ] ;
int tot = 1 , S , sum , ans[ MAXN ] , head[ MAXN ] , nex[ MAXN*2 ] , to[ MAXN*2 ] ;
void add( int x , int y ){
to[ ++tot ] = y , nex[ tot ] = head[ x ] , head[ x ] = tot ;
}
void dfs( int u , int fa ){
size[ u ]++ ;
for( int i = head[ u ] ; i ; i = nex[ i ] ){
if( to[ i ] == fa )continue ;
dfs( to[ i ] , u ) ;
size[ u ] += size[ to[i] ] ;
if( size[ to[i] ] > size[ son[u] ] )
son[ u ] = to[ i ] ;
}
}
void deal( int u , int fa , int val ){
cnt[ col[ u ] ] += val ;
if( cnt[ col[ u ] ] == 1 )sum++ ;
for( int i = head[ u ] ; i ; i = nex[ i ] ){
if( to[ i ] == fa )continue ;
if( to[ i ] != S )
deal( to[ i ] , u , val ) ;
}
}
void dfs2( int u , int fa , int opt ){
for( int i = head[ u ] ; i ; i = nex[ i ] ){
if( to[ i ] == fa )continue ;
if( to[ i ] != son[ u ] )
dfs2( to[ i ] , u , 0 ) ;//0清除
}
if( son[ u ] )dfs2( son[ u ] , u , 1 ) , S = son[ u ] ;
deal( u , fa , 1 ) ;//递归处理子树,统计轻儿子贡献
ans[ u ] = sum , S = 0 ;
if( !opt )deal( u , fa , -1 ) , sum = 0 ;//memset上仙
}
int main(){
N = read() ; int m1 , m2 ;
for( int i = 1 ; i < N ; ++i ){
m1 = read() , m2 = read() ;
add( m1 , m2 ) , add( m2 , m1 ) ;
}
for( int i = 1 ; i <= N ; ++i )col[ i ] = read() ;
dfs( 1 , 1 ) ;
dfs2( 1 , 1 , 0 ) ;
M = read() ;
for( int i = 1 ; i <= M ; ++i ){
m1 = read() ;
printf("%d\n",ans[ m1 ] ) ;
}
return 0 ;
}
CF600E Lomsat gelral
一棵树有n个结点,每个结点都是一种颜色,每个颜色有一个编号,求树中每个子树的最多的颜色编号的和。
AC代码:
#include<bits/stdc++.h>
using namespace std ;
#define ll long long
const int MAXN = 100005 ;
inline int read(){//备注:代码参考 @自为风月马前卒
int s=0 ; char g=getchar() ; while(g>'9'||g<'0')g=getchar() ;
while( g>='0'&&g<='9' )s=s*10+g-'0',g=getchar() ; return s ;
}
int N , col[ MAXN ] , son[ MAXN ] , size[ MAXN ] , cnt[ MAXN ] ;
int tot = 1 , S , head[ MAXN ] , nex[ MAXN*2 ] , to[ MAXN*2 ] ;
ll sum , mx , ans[ MAXN ] ;
void add( int x , int y ){
to[ ++tot ] = y , nex[ tot ] = head[ x ] , head[ x ] = tot ;
}
void dfs( int u , int fa ){
size[ u ]++ ;
for( int i = head[ u ] ; i ; i = nex[ i ] ){
if( to[ i ] == fa )continue ;
dfs( to[ i ] , u ) ;
size[ u ] += size[ to[i] ] ;
if( size[ to[i] ] > size[ son[u] ] )
son[ u ] = to[ i ] ;
}
}
void deal( int u , int fa , int val ){
cnt[ col[ u ] ] += val ;
if( cnt[ col[ u ] ] > mx )mx = cnt[ col[ u ] ] , sum = col[ u ] ;
else if( cnt[ col[ u ] ] == mx ) sum += (ll)col[ u ] ;
for( int i = head[ u ] ; i ; i = nex[ i ] ){
if( to[ i ] == fa )continue ;
if( to[ i ] != S )
deal( to[ i ] , u , val ) ;
}
}
void dfs2( int u , int fa , int opt ){
for( int i = head[ u ] ; i ; i = nex[ i ] ){
if( to[ i ] == fa )continue ;
if( to[ i ] != son[ u ] )
dfs2( to[ i ] , u , 0 ) ;//0清除
}
if( son[ u ] )dfs2( son[ u ] , u , 1 ) , S = son[ u ] ;
deal( u , fa , 1 ) ;//递归处理子树,统计轻儿子贡献
ans[ u ] = sum , S = 0 ;
if( !opt )deal( u , fa , -1 ) , sum = 0 , mx = 0 ;//memset上仙
}
int main(){
N = read() ; int m1 , m2 ;
for( int i = 1 ; i <= N ; ++i )col[ i ] = read() ;
for( int i = 1 ; i < N ; ++i ){
m1 = read() , m2 = read() ;
add( m1 , m2 ) , add( m2 , m1 ) ;
}
dfs( 1 , 1 ) ;
dfs2( 1 , 1 , 0 ) ;
for( int i = 1 ; i <= N ; ++i )printf("%lld ",ans[ i ] ) ;
return 0 ;
}