1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104
/// Node of a pointer-based segment tree. It covers the inclusive index range
/// [start, end] of the underlying array and stores the sum of those elements.
/// A node owns its children: the destructor deletes `left` and `right`, which
/// recursively frees the whole subtree.
struct SEGMENT_NODE {
    int start;
    int end;
    int sum;
    SEGMENT_NODE* left;   // owned child covering [start, mid]; nullptr for a leaf
    SEGMENT_NODE* right;  // owned child covering [mid + 1, end]; nullptr for a leaf

    SEGMENT_NODE(int input_start = 0, int input_end = 0, int input_sum = 0,
                 SEGMENT_NODE* input_left = nullptr, SEGMENT_NODE* input_right = nullptr)
        : start(input_start), end(input_end), sum(input_sum), left(input_left), right(input_right) {}

    // The children are owning raw pointers, so an implicit member-wise copy
    // would leave two nodes deleting the same subtree (double free).
    // Copying (and therefore moving) is disabled.
    SEGMENT_NODE(const SEGMENT_NODE&) = delete;
    SEGMENT_NODE& operator=(const SEGMENT_NODE&) = delete;

    ~SEGMENT_NODE() {
        delete left;
        delete right;
    }
};
using segmentNode = SEGMENT_NODE;

/// Segment tree over an int array supporting point assignment and inclusive
/// range-sum queries. The range is halved at every level, so each operation
/// touches O(log n) nodes.
class segmentTree {
  private:
    std::unique_ptr<segmentNode> root_;  // nullptr when built from an empty array

    segmentNode* buildTree(int start, int end, const std::vector<int>& nums);
    void updateTree(segmentNode* node, int index, int value);
    int querySum(segmentNode* node, int start, int end) const;

  public:
    /// Builds the tree over `nums`; an empty vector yields an empty tree.
    explicit segmentTree(const std::vector<int>& nums) {
        if (!nums.empty()) root_.reset(buildTree(0, static_cast<int>(nums.size()) - 1, nums));
    }

    /// Sets element `index` to `value`. Indices outside the array are ignored.
    void updateTree(int index, int value);

    /// Returns the sum of the elements in [start, end] (inclusive). Returns 0 for
    /// an empty tree, an inverted range, or a range not fully inside the array.
    int querySum(int start, int end) const;
};

/// Recursively builds the subtree for nums[start..end] (requires start <= end)
/// and returns an owning pointer to its root.
segmentNode* segmentTree::buildTree(int start, int end, const std::vector<int>& nums) {
    if (start == end) return new segmentNode(start, end, nums[start]);
    const int mid = start + (end - start) / 2;  // overflow-safe midpoint
    // Keep the children in unique_ptrs until the parent takes ownership, so a
    // throwing allocation does not leak an already-built subtree.
    std::unique_ptr<segmentNode> left(buildTree(start, mid, nums));
    std::unique_ptr<segmentNode> right(buildTree(mid + 1, end, nums));
    segmentNode* node = new segmentNode(start, end, left->sum + right->sum);
    node->left = left.release();
    node->right = right.release();
    return node;
}

void segmentTree::updateTree(int index, int value) { updateTree(root_.get(), index, value); }

/// Assigns `value` to element `index` inside the subtree rooted at `node` and
/// recomputes the sums along the path back up.
void segmentTree::updateTree(segmentNode* node, int index, int value) {
    // Skip subtrees that do not contain `index`. Without this check an
    // out-of-range index descended to the nearest edge leaf, found no child to
    // recurse into, and then overwrote that leaf's sum with 0.
    if (node == nullptr || index < node->start || index > node->end) return;
    if (node->start == node->end) {
        node->sum = value;
        return;
    }
    const int mid = node->start + (node->end - node->start) / 2;
    if (index <= mid) {
        updateTree(node->left, index, value);
    } else {
        updateTree(node->right, index, value);
    }
    node->sum = (node->left != nullptr ? node->left->sum : 0) + (node->right != nullptr ? node->right->sum : 0);
}

int segmentTree::querySum(int start, int end) const {
    if (root_ == nullptr || start > end || start < 0 || end > root_->end) return 0;
    return querySum(root_.get(), start, end);
}

/// Sum over [start, end], which the caller guarantees lies inside the node's range.
int segmentTree::querySum(segmentNode* node, int start, int end) const {
    if (node == nullptr) return 0;
    if (start == node->start && end == node->end) return node->sum;
    const int mid = node->start + (node->end - node->start) / 2;
    if (end <= mid) return querySum(node->left, start, end);
    if (start > mid) return querySum(node->right, start, end);
    // The range straddles the midpoint: split it between the two children.
    return querySum(node->left, start, mid) + querySum(node->right, mid + 1, end);
}
|