算法手记 之 数据结构(线段树详解)(POJ 3468)

依然延续第一篇读书笔记,这一篇是基于《ACM/ICPC 算法训练教程》上关于线段树的讲解的总结和修改(这本书在线段树这里Error非常多),但是总体来说这本书关于具体算法的讲解和案例都是不错的。


线段树简介  这是一种二叉搜索树,类似于区间树,是一种描述线段的树形数据结构,也是ACMer必学的一种数据结构,主要用于查询对一段数据的处理和存储查询,对时间度的优化也是较为明显的,优化后的时间复杂为O(logN)。此外,线段树还可以拓展为点树ZWK线段树等等,与此类似的还有树状数组等等。

  例如:要将数组s[]从[i,j]段上的元素均加上b,那么我们通常需要遍历每个元素(s[i],s[i+1]...s[j])并+b,此时使用的操作数为(j-i+1)次,但如果我们在某些情况下只关心[i,j]段内的总和呢,此时我们只需在[i,j]段内总和sum的基础上+b*(j-i+1)就行了,这样的操作数只需要一次。

  再者,若想知道[i,j]段内的和,直接输出此前存储的总和sum,这样比每次查询时都要遍历(j-i+1)个元素要好得多,因此参照树形结构可以引入一种表示一条线段上数据的结构。

  用数组模拟可以直观表述线段树如右图:算法手记 之 数据结构(线段树详解)(POJ 3468)


  具体实现和相应改进Code:

  定义

  每个结点的定义可以暂时如下:

struct Node{
int l, r; //左右端点坐标
int value; //值
}tree[MAXN];

  上面是一种简单直接的表示,但是对于需要经常更新数值的线段树来说,这种定义让线段树时间优化变得优势全无。

  因为如果对每一个[i,j]内的线段上每一个元素+b时,作为一段数据,我们可以+b*(j-i+1),但这一段的子树上的数据又该如何表示呢,难道一直遍历下去直到所有子结点遍历完并更新其中的数据嘛,这明显是个很愚蠢的做法,这样做会使得线段树的效率下降不少。

  我们在结点的定义上引入一个增量add(初始为0),使得每次更新数据时,在该结点及其子树全部更新数据后,再在该结点的增量add上+b,这样在每次查询或更新到它的子结点时,必然会遍历到该结点,此时查询该结点的add是否为0,如果不为0,则将add的值向下传递,更新子树结点上的value。(在需要时才进行更新是一个很好的算法优化)

  因此我们可以改进上面关于结点的定义,最终定义如下:

 /*Tree*/
struct Node{
int l, r; //左右端点坐标
int value; //值
int add; //子树各结点应add的值
}tree[MAXN];

  

  搭建

  那么我们该如何搭建一个线段树呢,我们利用树形结构的思想,不断得二分得到左儿子和右儿子。原结点的value就靠左右儿子的value相加得到。

  具体如下:

 /*从x结点开始扩展线段树*/
void build(int x, int l, int r)
{
tree[x].l = l;
tree[x].r = r;
if (l == r){
tree[x].value = source[l];
return;
}
int mid = (l + r) / ;
build(x * , l, mid);
build(x * + , mid + , r);
tree[x].value = tree[ * x].value + tree[ * x + ].value;
tree[x].add = ;
}

  更新

  此处开始对书上的Code做了修改和改进。

  那么为了进行一段数据上数据的更新,我们在上面已经引入了add增量表示,具体做法如下:

 /*更新-在[l,r]线段上加上m*/
void update(int x, int l, int r, int m)
{
// update
tree[x].value += m*(r - l + );
// Hit!
if (tree[x].l == l && tree[x].r == r){
tree[x].add += m;
return;
}
// add - Transfer
if (tree[x].add){
tree[ * x].add += tree[x].add;
tree[ * x].value += tree[x].add*(tree[ * x].r - tree[ * x].l + );
tree[ * x + ].add += tree[x].add;
tree[ * x + ].value += tree[x].add*(tree[ * x + ].r - tree[ * x + ].l + );
tree[x].add = ;
}
// continue - Search
int mid = (tree[x].l + tree[x].r)/;
if (r <= mid) //[l,r]在mid右侧
update( * x, l, r, m);
else if (l >= mid) //[l,r]在mid左侧
update( * x + , l, r, m);
else{ //[l,r]横跨mid
update( * x, l, mid, m);
update( * x + , mid + , r, m);
}
}

  查询

  也就是查询某段上的数据value

 //最终查询值
int ans = ;
/*查询*/
void query(int x, int l, int r)
{
// Hit!
if (tree[x].l == l && tree[x].r == r)
{
ans += tree[x].value;
return;
}
// add - Transfer
if (tree[x].add){
tree[ * x].add += tree[x].add;
tree[ * x].value += tree[x].add*(tree[ * x].r - tree[ * x].l + );
tree[ * x + ].add += tree[x].add;
tree[ * x + ].value += tree[x].add*(tree[ * x + ].r - tree[ * x + ].l + );
tree[x].add = ;
}
// continue - Search
int mid = (tree[x].l + tree[x].r)/;
if (r <= mid) //[l,r]在mid左侧
query( * x, l, r);
else if (l >= mid) //[l,r]在mid右侧
query( * x + , l, r);
else{ //[l,r]横跨mid
query( * x, l, mid);
query( * x + , mid + , r);
}
}

  Ps:另外对于一个源数组source[MAX],线段树往往所需的空间要稍大一点,大约为4*MAX.

    最少需要空间为2*MAX,最多需要空间为4*MAX


  在POJ上有一个裸线段树例题---POJ3468

  题目大意就是给一个区间上的sum进行两个操作-1,查询,2,区间上每个点完成一次加法。

 //线段处理-线段树
//在一个区间内处理数据的加减和查询-裸线段树
//Memory:6732K Time:1579Ms
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std; #define MAX 100005 int n, q; //n:原数据量 q:查询量
int s[MAX]; //source date
__int64 ans; //查询结果 /*interval_tree*/
struct Node{
int l, r;
__int64 value;
__int64 add;
}tr[*MAX]; //线段树最少需要2*MAX,最多需要4*MAX /*搭建interval-tree*/
void build(int x,int l,int r)
{
tr[x].l = l;
tr[x].r = r;
if (tr[x].l == tr[x].r){ //规模缩减到单个数据
tr[x].value = s[l];
return;
}
int mid = (l + r) / ;
build( * x, l, mid);
build( * x + , mid + , r);
tr[x].value = tr[ * x].value + tr[ * x + ].value;//该结点value由子树结点决定
tr[x].add = ; //Init
} /*更新-从x向下扩展每个结点+m*/
void update(int x,int l,int r,int m)
{
// update
tr[x].value += m*(r - l + );
// Hit
if (tr[x].l == l && tr[x].r == r){
tr[x].add += m;
return;
}
// add - transfer
if (tr[x].add){
tr[ * x].add += tr[x].add;
tr[ * x + ].add += tr[x].add;
tr[ * x].value += tr[x].add*(tr[ * x].r - tr[ * x].l + );
tr[ * x + ].value += tr[x].add*(tr[ * x + ].r - tr[ * x + ].l + );
tr[x].add = ;
}
// Search
int mid = (tr[x].r + tr[x].l) / ; //该段中点
if (r <= mid)
update( * x, l, r, m);
else if (l > mid)
update( * x + , l, r, m);
else{
update( * x, l, mid, m);
update( * x + , mid + , r, m);
}
} /*查询-interval-date*/
void query(int x,int l,int r)
{
// Hit
if (tr[x].l == l && tr[x].r == r){
ans += tr[x].value;
return;
}
// add - transfer
if (tr[x].add){
tr[ * x].add += tr[x].add;
tr[ * x + ].add += tr[x].add;
tr[ * x].value += tr[x].add*(tr[ * x].r - tr[ * x].l + );
tr[ * x + ].value += tr[x].add*(tr[ * x + ].r - tr[ * x + ].l + );
tr[x].add = ;
}
// Search
int mid = (tr[x].r + tr[x].l) / ; //该段中点
if (r <= mid)
query( * x, l, r);
else if (l > mid)
query( * x + , l, r);
else{
query( * x, l, mid);
query( * x + , mid + , r);
}
} int main()
{
scanf("%d%d", &n, &q);
for (int i = ; i <= n; i++)
scanf("%d", &s[i]);
build(, , n); //Creat_interval tree
while (q--)
{
char ch; //command
int low, high, dig;
scanf("\n%c", &ch);
if (ch == 'C'){
scanf("%d%d%d", &low, &high, &dig);
update(, low, high, dig);
}
else if (ch == 'Q'){
ans = ;
scanf("%d%d", &low, &high);
query(, low, high);
printf("%I64d\n", ans);
}
} return ;
}

小墨- -原创


上一篇:黑马程序员—— Java SE(2)


下一篇:light1341 唯一分解定理