TheRiver | blog

You have reached the world's edge, none but devils play past here

0%

线段树 segment tree

线段树(英语:Segment tree)是一种二叉树形数据结构,1977年由Jon Louis Bentley发明,用以存储区间或线段,并且允许快速查询结构内包含某一点的所有区间。

一个包含n个区间的线段树,空间复杂度为O(n log n),查询的时间复杂度则为O(log n + k),其中k是符合条件的区间数量。

code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
/**
 * One node of a segment tree covering the inclusive index range [start, end].
 *
 * A node owns its two children: the destructor deletes `left` and `right`,
 * which recursively tears down the whole subtree. Pointers handed to the
 * constructor must therefore come from `new` and must not be deleted or
 * owned by anyone else.
 */
typedef struct SEGMENT_NODE
{
    int start;            // first index covered by this node (inclusive)
    int end;              // last index covered by this node (inclusive)
    int sum;              // aggregate over [start, end]; can be max/min
    SEGMENT_NODE* left;   // owned child, may be nullptr
    SEGMENT_NODE* right;  // owned child, may be nullptr

    /**
     * @param input_start first covered index (inclusive)
     * @param input_end   last covered index (inclusive)
     * @param input_sum   aggregate value stored in the node
     * @param input_left  left child; ownership is transferred to this node
     * @param input_right right child; ownership is transferred to this node
     */
    SEGMENT_NODE(int input_start = 0, int input_end = 0, int input_sum = 0,
                 SEGMENT_NODE* input_left = nullptr, SEGMENT_NODE* input_right = nullptr)
        : start(input_start),
          end(input_end),
          sum(input_sum),
          left(input_left),
          right(input_right)
    {
    }

    // The node owns raw pointers, so a member-wise copy would make two nodes
    // delete the same children. Copying is disabled; this also suppresses the
    // implicit move operations.
    SEGMENT_NODE(const SEGMENT_NODE&) = delete;
    SEGMENT_NODE& operator=(const SEGMENT_NODE&) = delete;

    /** Recursively destroys both subtrees (deleting nullptr is a no-op). */
    ~SEGMENT_NODE()
    {
        delete left;
        delete right;
    }
} segmentNode;

class segmentTree
{
private:
unique_ptr<segmentNode> root_;
segmentNode* buildTree(int start, int end, vector<int>& nums);
void updateTree(segmentNode* node, int index, int value);
int querySum(segmentNode* node, int start, int end);

public:
explicit segmentTree(vector<int>& nums)
{
if (!nums.empty())
root_.reset(buildTree(0, nums.size() - 1, nums));
};

void updateTree(int index, int value);
int querySum(int start, int end);
};

/**
 * Recursively builds the subtree covering nums[start..end] (inclusive).
 *
 * @param start first index of the range; callers must ensure start <= end
 * @param end   last index of the range
 * @param nums  source values; only read
 * @return a newly allocated subtree root; the caller takes ownership
 */
segmentNode* segmentTree::buildTree(int start, int end, std::vector<int>& nums)
{
    // A single-element range is a leaf holding that element.
    if (start == end)
        return new segmentNode(start, end, nums[start]);

    const int mid = start + (end - start) / 2;

    // Hold the subtrees in unique_ptr while the rest is built, so that an
    // exception (e.g. std::bad_alloc) from a later allocation does not leak
    // the nodes that were already created.
    std::unique_ptr<segmentNode> left(buildTree(start, mid, nums));
    std::unique_ptr<segmentNode> right(buildTree(mid + 1, end, nums));

    segmentNode* parent =
        new segmentNode(start, end, left->sum + right->sum, left.get(), right.get());

    // The parent now owns both children; give up the temporary ownership.
    left.release();
    right.release();
    return parent;
}

/**
 * Sets the element at `index` to `value` and refreshes every sum on the path
 * from the root to that element.
 *
 * The call is a no-op when the tree is empty or when `index` lies outside the
 * range covered by the root. This matches querySum(), which also rejects
 * out-of-range arguments, and keeps an invalid index from reaching the
 * recursive helper.
 */
void segmentTree::updateTree(int index, int value)
{
    if (root_ == nullptr || index < root_->start || index > root_->end)
        return;

    updateTree(root_.get(), index, value);
}

/**
 * Recursive worker for updateTree(): sets element `index` to `value` inside
 * the subtree rooted at `node`, then recomputes the sums on the way back up.
 *
 * @param node  subtree root; nullptr is ignored
 * @param index position of the element to overwrite
 * @param value new value for that element
 */
void segmentTree::updateTree(segmentNode* node, int index, int value)
{
    // Ignore indices this subtree does not cover. Without this check an
    // out-of-range index would walk down to the nearest boundary leaf, recurse
    // into its null child, and the sum recomputation below would then reset
    // that leaf to 0.
    if (node == nullptr || index < node->start || index > node->end)
        return;

    // Leaf: after the range check above, start == end implies index == start.
    if (node->start == node->end)
    {
        node->sum = value;
        return;
    }

    const int mid = node->start + (node->end - node->start) / 2;

    // The left child covers [start, mid], the right child covers [mid + 1, end].
    if (index <= mid)
        updateTree(node->left, index, value);
    else
        updateTree(node->right, index, value);

    // Refresh this node's aggregate from its children.
    const int left_sum = (node->left != nullptr) ? node->left->sum : 0;
    const int right_sum = (node->right != nullptr) ? node->right->sum : 0;
    node->sum = left_sum + right_sum;
}

/**
 * Returns the sum of the elements in the inclusive range [start, end].
 *
 * Returns 0 when the tree is empty, when start > end, when start is negative,
 * or when end is past the last index covered by the root.
 */
int segmentTree::querySum(int start, int end)
{
    if (!root_)
        return 0;

    const bool range_ok = (start <= end) && (start >= 0) && (end <= root_->end);
    return range_ok ? querySum(root_.get(), start, end) : 0;
}

/**
 * Recursive worker for querySum(): sums the inclusive range [start, end]
 * within the subtree rooted at `node`.
 *
 * A null node contributes 0. When the requested range equals the node's own
 * range, the stored sum is returned directly; otherwise the query is routed
 * to the child that contains it, or split at the midpoint between both.
 */
int segmentTree::querySum(segmentNode* node, int start, int end)
{
    if (!node)
        return 0;

    const bool covers_exactly = (node->start == start) && (node->end == end);
    if (covers_exactly)
        return node->sum;

    const int split = node->start + (node->end - node->start) / 2;

    // Range lies entirely in the left half [node->start, split].
    if (end <= split)
        return querySum(node->left, start, end);

    // Range lies entirely in the right half [split + 1, node->end].
    if (start > split)
        return querySum(node->right, start, end);

    // Range straddles the midpoint: query each half and add the results.
    const int left_part = querySum(node->left, start, split);
    const int right_part = querySum(node->right, split + 1, end);
    return left_part + right_part;
}

307 solution

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class NumArray {
public:
NumArray(vector<int>& nums):segTree_(nums) {
}

void update(int index, int val) {
segTree_.updateTree(index, val);
}

int sumRange(int left, int right) {
return segTree_.querySum(left, right);
}

private:
segmentTree segTree_;
};

reference

https://www.youtube.com/watch?v=rYBtViWXYeI

leetcode307

----------- ending -----------