Segment Tree

Leetcode上一道题,给定一个整数数组,要实现:

  • 求[i, j]所有元素的和,0 <= i <= j <= n - 1,sumRange(i, j)
  • 数组的元素会被修改, update(i, val)
  • 这两个函数会被均匀的调用很多次

最简单的方法是求和O(n),修改元素O(1),时间复杂度太大,使用Segment Tree可以将二者的时间复杂度均变为O(logn)

Segment Tree:

  • 叶子节点是输入数组中的所有元素
  • 内部节点是其孩子节点所带信息的merge
  • Segment Tree可以由数组实现,数组索引i的左孩子为2 * i + 1,右孩子为2 * i + 2,父节点在(i - 1) / 2下取整处
  • Segment Tree的高度为$\lceil log_2n \rceil$, 因此该数组的大小为$2 * 2 ^\lceil log_2n\rceil - 1$

一下代码是上述题目的C++递归实现:

 1 #include <math.h>
 2 #include <vector>
 3 #include <iostream>
 4 using namespace std;
 5 
 6 class segTree{
 7 public:
 8     vector<int> tree;
 9     int n;
10     segTree(vector<int>& arr){
11         n = arr.size();
12         int treeSize = 2 * pow(2, ceil(log2(double(n)))) - 1;
13         tree.resize(treeSize);
14         buildSegTree(arr, 0, 0, n - 1);
15     }
16     
17     
18     void buildSegTree(vector<int>& arr, int treeIndex, int lo, int hi){
19         if(lo == hi){
20             tree[treeIndex] = arr[lo];
21             return;
22         }
23         int mid = lo + (hi - lo) / 2;
24         buildSegTree(arr, 2 * treeIndex + 1, lo, mid);
25         buildSegTree(arr, 2 * treeIndex + 2, mid + 1, hi);
26         tree[treeIndex] = merge(tree[2 * treeIndex + 1], tree[2 * treeIndex + 2]);
27     }
28 
29 
30     int querySegTree(int treeIndex, int lo, int hi, int i, int j){
31         if(lo > j || hi < i)
32             return 0;
33         if(i <= lo && j >= hi)
34             return tree[treeIndex];
35         
36         int mid = lo + (hi - lo) / 2;
37         
38         if(i > mid)
39             return querySegTree(2 * treeIndex + 2, mid + 1, hi, i, j);
40         else if(j <= mid)
41             return querySegTree(2 * treeIndex + 1, lo, mid, i, j);
42         
43         int leftQuery = querySegTree(2 * treeIndex + 1, lo, mid, i, mid);
44         int rightQuery = querySegTree(2 * treeIndex + 2, mid + 1, hi, mid + 1, j);
45 
46         return merge(leftQuery, rightQuery);
47     }
48 
49 
50     void updateValSegTree(int treeIndex, int lo, int hi, int arrIndex, int val){
51         if(lo == hi){
52             tree[treeIndex] = val;
53             return;
54         }
55 
56         int mid = lo + (hi - lo) / 2;
57 
58         if(arrIndex > mid)
59             updateValSegTree(2 * treeIndex + 2, mid + 1, hi, arrIndex, val);
60         else if(arrIndex <= mid)
61             updateValSegTree(2 * treeIndex + 1, lo, mid, arrIndex, val);
62         
63         tree[treeIndex] = merge(tree[2 * treeIndex + 1] , tree[2 * treeIndex + 2]);
64     }
65 
66 
67     int merge(int& v1, int& v2){
68         return v1 + v2;
69     }
70 };
71 
72 
73 int main(){
74     vector<int> arr1 = {1, 3, 5, 7, 9, 11};
75     segTree test(arr1);
76     for(int item : test.tree)
77         cout << item << " ";
78     cout << endl;
79     int sum = test.querySegTree(0, 0, test.n - 1, 0, 2);
80     cout << "sum = " << sum << endl;
81     test.updateValSegTree(0, 0, test.n - 1, 1, 4);
82     int updatedSum  = test.querySegTree(0, 0, test.n - 1, 0, 2);
83     cout << "updated sum = " << updatedSum << endl;
84     return 0;
85 }

 

上一篇:LeetCode 4.寻找两个有组数组的中位数


下一篇:数据结构学习第十一天