Segment Trees
Tree structure for range queries and updates.
1. Building segment trees
Understanding construction and structure:
Concept:
- Purpose: Efficiently answer range queries and apply updates
- Structure: Binary tree where each node represents an interval
- Property: Each node stores aggregate information about its range
- Height: O(log n) for n elements
Array representation:
# For array of size n, segment tree needs 4*n space
# Node i has children at 2*i and 2*i+1
# Parent of node i is at i//2
class SegmentTree:
    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n)
        self.arr = arr
        if self.n > 0:
            self.build(1, 0, self.n - 1)
Building the tree (recursive):
def build(self, node, start, end):
    """
    Build segment tree recursively
    node: current node index in tree array
    start, end: range this node represents
    """
    if start == end:
        # Leaf node
        self.tree[node] = self.arr[start]
        return
    mid = (start + end) // 2
    left_child = 2 * node
    right_child = 2 * node + 1
    # Build left and right subtrees
    self.build(left_child, start, mid)
    self.build(right_child, mid + 1, end)
    # Internal node stores sum of children (filled in as the recursion returns)
    self.tree[node] = self.tree[left_child] + self.tree[right_child]
Visual example:
Array: [1, 3, 5, 7, 9, 11]
Segment Tree (sum):
                    [0-5: 36]
                   /         \
           [0-2: 9]           [3-5: 27]
           /      \            /       \
    [0-1: 4]    [2: 5]   [3-4: 16]    [5: 11]
     /     \               /     \
[0: 1]   [1: 3]       [3: 7]   [4: 9]
Tree array: [ _, 36,  9, 27,  4,  5, 16, 11,  1,  3,  _,  _,  7,  9,  _,  _]
Indices:      0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15
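The layout above can be checked with a few lines of Python. The sketch below is a standalone version of the recursive build; the function name `build_sum_tree` is an assumption made for this example and is not part of the class above. Slots the build never touches are left as `None`.

```python
def build_sum_tree(arr):
    """Return the 1-indexed tree array of a sum segment tree over arr."""
    n = len(arr)
    tree = [None] * (4 * n)  # None marks slots the build never writes

    def build(node, start, end):
        if start == end:
            tree[node] = arr[start]  # leaf
            return
        mid = (start + end) // 2
        build(2 * node, start, mid)
        build(2 * node + 1, mid + 1, end)
        tree[node] = tree[2 * node] + tree[2 * node + 1]

    if n > 0:
        build(1, 0, n - 1)
    return tree


tree = build_sum_tree([1, 3, 5, 7, 9, 11])
print(tree[:16])
# [None, 36, 9, 27, 4, 5, 16, 11, 1, 3, None, None, 7, 9, None, None]
```

Index 0 and the children of leaf node 5 (indices 10 and 11) stay unused, which is why the array is sized with slack.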
Iterative build (no recursion, 2n space):
def build_iterative(self, arr):
    """
    Build segment tree iteratively
    Layout differs from the recursive version: leaves live at
    indices n..2n-1, so the recursive query/update code does not apply to it
    """
    n = len(arr)
    self.tree = [0] * (2 * n)
    # Copy array to second half
    for i in range(n):
        self.tree[n + i] = arr[i]
    # Build tree by calculating parents
    for i in range(n - 1, 0, -1):
        self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]
Time and space:
- Build time: O(n)
- Space: O(4n) = O(n) for recursive, O(2n) for iterative
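The 2n layout needs its own query routine, because node ranges are implied by position rather than passed down. A minimal sketch of the usual bottom-up sum query follows; the names `build_iterative_tree` and `query_iterative` are assumptions for this example, and the approach as written relies on the operation being commutative (true for sum).

```python
def build_iterative_tree(arr):
    """Build the 2n-slot tree: leaves at n..2n-1, parents below them."""
    n = len(arr)
    tree = [0] * (2 * n)
    tree[n:] = arr
    for i in range(n - 1, 0, -1):
        tree[i] = tree[2 * i] + tree[2 * i + 1]
    return tree


def query_iterative(tree, n, left, right):
    """Sum of arr[left..right] (inclusive) on the 2n layout."""
    total = 0
    lo = left + n          # first leaf of the range
    hi = right + n + 1     # one past the last leaf (half-open)
    while lo < hi:
        if lo & 1:         # lo is a right child: take it and step right
            total += tree[lo]
            lo += 1
        if hi & 1:         # hi is a right child: step left and take it
            hi -= 1
            total += tree[hi]
        lo >>= 1
        hi >>= 1
    return total


arr = [1, 3, 5, 7, 9, 11]
tree = build_iterative_tree(arr)
print(query_iterative(tree, len(arr), 2, 4))  # 21
```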
2. Range sum/min/max queries
Efficiently querying aggregate values over ranges:
Range sum query:
def query_sum(self, node, start, end, left, right):
    """
    Query sum in range [left, right]
    node: current node
    [start, end]: range of current node
    [left, right]: query range
    """
    # No overlap
    if right < start or left > end:
        return 0
    # Complete overlap
    if left <= start and end <= right:
        return self.tree[node]
    # Partial overlap
    mid = (start + end) // 2
    left_sum = self.query_sum(2 * node, start, mid, left, right)
    right_sum = self.query_sum(2 * node + 1, mid + 1, end, left, right)
    return left_sum + right_sum

# Wrapper method
def range_sum(self, left, right):
    return self.query_sum(1, 0, self.n - 1, left, right)
Range minimum query:
class SegmentTreeMin:
    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [float('inf')] * (4 * self.n)
        self.arr = arr
        if self.n > 0:
            self.build(1, 0, self.n - 1)

    def build(self, node, start, end):
        if start == end:
            self.tree[node] = self.arr[start]
            return
        mid = (start + end) // 2
        self.build(2 * node, start, mid)
        self.build(2 * node + 1, mid + 1, end)
        # Store minimum of children
        self.tree[node] = min(self.tree[2 * node],
                              self.tree[2 * node + 1])

    def query_min(self, node, start, end, left, right):
        if right < start or left > end:
            return float('inf')
        if left <= start and end <= right:
            return self.tree[node]
        mid = (start + end) // 2
        left_min = self.query_min(2 * node, start, mid, left, right)
        right_min = self.query_min(2 * node + 1, mid + 1, end, left, right)
        return min(left_min, right_min)
Range maximum query:
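class SegmentTreeMax:
    def __init__(self, arr):
        self.n = len(arr)
        # -inf is the identity for max, mirroring +inf in SegmentTreeMin
        self.tree = [float('-inf')] * (4 * self.n)
        self.arr = arr
        if self.n > 0:
            self.build(1, 0, self.n - 1)

    def build(self, node, start, end):
        if start == end:
            self.tree[node] = self.arr[start]
            return
        mid = (start + end) // 2
        self.build(2 * node, start, mid)
        self.build(2 * node + 1, mid + 1, end)
        # Store maximum of children
        self.tree[node] = max(self.tree[2 * node],
                              self.tree[2 * node + 1])
    # query_max mirrors query_min, returning float('-inf') when there is no overlap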
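The sum, min and max trees differ only in the combining operation and in the value returned for "no overlap". One way to avoid three near-identical classes is to pass both in as parameters. The class below is a sketch under that assumption; `GenericSegmentTree`, `op` and `identity` are names chosen for this example.

```python
import math


class GenericSegmentTree:
    """Segment tree over any associative op that has an identity element."""

    def __init__(self, arr, op, identity):
        self.n = len(arr)
        self.op = op
        self.identity = identity
        self.tree = [identity] * (4 * self.n)
        if self.n > 0:
            self._build(arr, 1, 0, self.n - 1)

    def _build(self, arr, node, start, end):
        if start == end:
            self.tree[node] = arr[start]
            return
        mid = (start + end) // 2
        self._build(arr, 2 * node, start, mid)
        self._build(arr, 2 * node + 1, mid + 1, end)
        self.tree[node] = self.op(self.tree[2 * node], self.tree[2 * node + 1])

    def query(self, left, right, node=1, start=0, end=None):
        """Aggregate over the inclusive range [left, right]."""
        if end is None:
            end = self.n - 1
        if right < start or left > end:
            return self.identity          # no overlap
        if left <= start and end <= right:
            return self.tree[node]        # complete overlap
        mid = (start + end) // 2
        return self.op(self.query(left, right, 2 * node, start, mid),
                       self.query(left, right, 2 * node + 1, mid + 1, end))


data = [1, 3, 5, 7, 9, 11]
sum_tree = GenericSegmentTree(data, lambda a, b: a + b, 0)
min_tree = GenericSegmentTree(data, min, float('inf'))
max_tree = GenericSegmentTree(data, max, float('-inf'))
gcd_tree = GenericSegmentTree([12, 18, 24, 36], math.gcd, 0)
print(sum_tree.query(2, 4), min_tree.query(2, 4), max_tree.query(2, 4))  # 21 5 9
print(gcd_tree.query(1, 3))  # 6
```

The left operand is always the left subrange, so the same class also works for associative operations that are not commutative.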
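Query visualization:
Query: sum(2, 4) in array [1, 3, 5, 7, 9, 11]
Tree traversal (✓ = fully inside the query, value used; ✗ = no overlap, skipped; ~ = partial overlap, recurse):
                  [0-5: 36]~
                 /          \
        [0-2: 9]~            [3-5: 27]~
        /       \            /        \
 [0-1: 4]✗    [2: 5]✓   [3-4: 16]✓   [5: 11]✗
Nodes below [3-4] are never visited, because [3-4] lies fully inside the query.
Result: 5 (from [2: 5]) + 16 (from [3-4: 16]) = 21
Check: 5 + 7 + 9 = 21 ✓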
Time complexity: O(log n) per query
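To see why a query is logarithmic, it helps to record which nodes the recursion actually touches. The sketch below is a standalone, instrumented copy of the sum query; the `visited` list and the free-function form are assumptions made for this example.

```python
def build(arr, tree, node, start, end):
    if start == end:
        tree[node] = arr[start]
        return
    mid = (start + end) // 2
    build(arr, tree, 2 * node, start, mid)
    build(arr, tree, 2 * node + 1, mid + 1, end)
    tree[node] = tree[2 * node] + tree[2 * node + 1]


def query_sum(tree, node, start, end, left, right, visited):
    visited.append((start, end))  # record every node the query touches
    if right < start or left > end:
        return 0
    if left <= start and end <= right:
        return tree[node]
    mid = (start + end) // 2
    return (query_sum(tree, 2 * node, start, mid, left, right, visited) +
            query_sum(tree, 2 * node + 1, mid + 1, end, left, right, visited))


arr = [1, 3, 5, 7, 9, 11]
tree = [0] * (4 * len(arr))
build(arr, tree, 1, 0, len(arr) - 1)
visited = []
print(query_sum(tree, 1, 0, len(arr) - 1, 2, 4, visited))  # 21
print(visited)
# [(0, 5), (0, 2), (0, 1), (2, 2), (3, 5), (3, 4), (5, 5)]
```

At most two nodes per level overlap the query only partially, so the number of visited nodes grows with the height of the tree rather than with the length of the range.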
3. Point and range updates
Modifying values in the segment tree:
Point update:
def update_point(self, node, start, end, idx, value):
    """
    Update element at index idx to value
    """
    if start == end:
        # Leaf node
        self.arr[idx] = value
        self.tree[node] = value
        return
    mid = (start + end) // 2
    if idx <= mid:
        # Update left child
        self.update_point(2 * node, start, mid, idx, value)
    else:
        # Update right child
        self.update_point(2 * node + 1, mid + 1, end, idx, value)
    # Update current node
    self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

# Wrapper
def update(self, idx, value):
    self.update_point(1, 0, self.n - 1, idx, value)
Range update (naive):
def update_range_naive(self, left, right, delta):
    """
    Add delta to all elements in [left, right]
    Naive approach: update each element individually
    Time: O(k log n) for a range of k elements, O(n log n) worst case - BAD!
    """
    for i in range(left, right + 1):
        self.update(i, self.arr[i] + delta)
Range update with difference array:
class SegmentTreeRangeUpdate(SegmentTree):
    """
    For range updates, keep a sum segment tree over the difference array
    diff[0] = arr[0], diff[i] = arr[i] - arr[i-1]
    so that arr[i] = diff[0] + ... + diff[i]
    Assumes build, update and query_sum above are methods of SegmentTree
    """
    def __init__(self, arr):
        # Build difference array
        diff = arr[:1] + [arr[i] - arr[i-1]
                          for i in range(1, len(arr))]
        # SegmentTree stores diff as self.arr and builds the sums over it
        super().__init__(diff)

    def range_update(self, left, right, delta):
        """Add delta to range [left, right] in O(log n)"""
        # update() overwrites a value, so pass the old difference plus delta
        self.update(left, self.arr[left] + delta)
        if right + 1 < self.n:
            self.update(right + 1, self.arr[right + 1] - delta)

    def point_query(self, idx):
        """Get value at index idx"""
        return self.query_sum(1, 0, self.n - 1, 0, idx)
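The difference-array idea can be tried in isolation. The following is a self-contained sketch that uses the compact 2n layout instead of the recursive classes; the class name `RangeAddPointQuery` and its helper methods are assumptions for this example. Note the trade: range updates become cheap, but only single elements can be read back.

```python
class RangeAddPointQuery:
    """Range add / point query via a sum tree over the difference array."""

    def __init__(self, arr):
        self.n = len(arr)
        diff = [arr[i] - (arr[i - 1] if i > 0 else 0) for i in range(self.n)]
        self.tree = [0] * self.n + diff          # leaves at n..2n-1
        for i in range(self.n - 1, 0, -1):
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]

    def _add(self, idx, delta):
        """diff[idx] += delta, fixing every ancestor on the way up."""
        i = idx + self.n
        while i >= 1:
            self.tree[i] += delta
            i //= 2

    def _prefix_sum(self, idx):
        """diff[0] + ... + diff[idx], which equals the current arr[idx]."""
        total = 0
        lo, hi = self.n, idx + self.n + 1
        while lo < hi:
            if lo & 1:
                total += self.tree[lo]
                lo += 1
            if hi & 1:
                hi -= 1
                total += self.tree[hi]
            lo >>= 1
            hi >>= 1
        return total

    def range_update(self, left, right, delta):
        self._add(left, delta)                   # range starts here
        if right + 1 < self.n:
            self._add(right + 1, -delta)         # cancel after the range

    def point_query(self, idx):
        return self._prefix_sum(idx)


t = RangeAddPointQuery([1, 2, 3, 4, 5])
t.range_update(1, 3, 10)
print([t.point_query(i) for i in range(5)])  # [1, 12, 13, 14, 5]
```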
Update visualization:
Update index 2 from 5 to 8 (delta = +3)
Before:
          [0-5: 36]
          /       \
    [0-2: 9]     [3-5: 27]
    /      \
[0-1: 4]  [2: 5]←
After:
          [0-5: 39]
          /       \
    [0-2: 12]    [3-5: 27]
    /      \
[0-1: 4]  [2: 8]←
Updated nodes: [2], [0-2], [0-5] - O(log n) nodes
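The same update can be replayed in code to confirm that only the nodes on the root-to-leaf path change. This is a standalone sketch with free functions (an assumption for brevity); node indices 1, 2 and 5 are the array slots of [0-5], [0-2] and [2] in the 1-indexed layout.

```python
def build(arr, tree, node, start, end):
    if start == end:
        tree[node] = arr[start]
        return
    mid = (start + end) // 2
    build(arr, tree, 2 * node, start, mid)
    build(arr, tree, 2 * node + 1, mid + 1, end)
    tree[node] = tree[2 * node] + tree[2 * node + 1]


def update_point(tree, node, start, end, idx, value):
    if start == end:
        tree[node] = value
        return
    mid = (start + end) // 2
    if idx <= mid:
        update_point(tree, 2 * node, start, mid, idx, value)
    else:
        update_point(tree, 2 * node + 1, mid + 1, end, idx, value)
    tree[node] = tree[2 * node] + tree[2 * node + 1]  # recompute on the way up


arr = [1, 3, 5, 7, 9, 11]
tree = [0] * (4 * len(arr))
build(arr, tree, 1, 0, len(arr) - 1)
before = list(tree)
update_point(tree, 1, 0, len(arr) - 1, 2, 8)
changed = [i for i in range(len(tree)) if tree[i] != before[i]]
print(changed)                       # [1, 2, 5]
print(tree[1], tree[2], tree[5])     # 39 12 8
```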
4. Lazy propagation
Optimizing range updates to O(log n):
Concept:
- Problem: Naive range update is O(n log n)
- Solution: Defer updates using lazy propagation
- Idea: Mark nodes as “lazy” instead of updating immediately
- Propagate: Push lazy values down only when needed
Implementation:
class SegmentTreeLazy:
    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n)
        self.lazy = [0] * (4 * self.n)  # Lazy array
        self.arr = arr
        if self.n > 0:
            self.build(1, 0, self.n - 1)

    def build(self, node, start, end):
        if start == end:
            self.tree[node] = self.arr[start]
            return
        mid = (start + end) // 2
        self.build(2 * node, start, mid)
        self.build(2 * node + 1, mid + 1, end)
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def push_down(self, node, start, end):
        """Apply this node's pending value to its sum, then pass it to the children"""
        if self.lazy[node] != 0:
            # Apply lazy value to current node
            self.tree[node] += (end - start + 1) * self.lazy[node]
            # If not leaf, propagate to children
            if start != end:
                self.lazy[2 * node] += self.lazy[node]
                self.lazy[2 * node + 1] += self.lazy[node]
            # Clear lazy value
            self.lazy[node] = 0

    def update_range(self, node, start, end, left, right, delta):
        """Add delta to range [left, right]"""
        # Push down pending updates
        self.push_down(node, start, end)
        # No overlap
        if right < start or left > end:
            return
        # Complete overlap
        if left <= start and end <= right:
            # Mark as lazy instead of updating the whole subtree
            self.lazy[node] += delta
            self.push_down(node, start, end)
            return
        # Partial overlap
        mid = (start + end) // 2
        self.update_range(2 * node, start, mid, left, right, delta)
        self.update_range(2 * node + 1, mid + 1, end, left, right, delta)
        # Update current node after children are updated
        # (the recursive calls already pushed both children down,
        # so these two calls are only a safeguard)
        self.push_down(2 * node, start, mid)
        self.push_down(2 * node + 1, mid + 1, end)
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def query_range(self, node, start, end, left, right):
        """Query sum in range [left, right]"""
        # Push down pending updates
        self.push_down(node, start, end)
        # No overlap
        if right < start or left > end:
            return 0
        # Complete overlap
        if left <= start and end <= right:
            return self.tree[node]
        # Partial overlap
        mid = (start + end) // 2
        left_sum = self.query_range(2 * node, start, mid, left, right)
        right_sum = self.query_range(2 * node + 1, mid + 1, end, left, right)
        return left_sum + right_sum

    # Wrapper methods
    def update(self, left, right, delta):
        self.update_range(1, 0, self.n - 1, left, right, delta)

    def query(self, left, right):
        return self.query_range(1, 0, self.n - 1, left, right)
Lazy propagation example:
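Array: [1, 2, 3, 4, 5]
Update: add 10 to range [1, 3]
Without lazy (naive):
- Update index 1: O(log n)
- Update index 2: O(log n)
- Update index 3: O(log n)
Total: O(k log n) for k elements, O(n log n) in the worst case
With lazy:
- Visit only the O(log n) nodes that together cover [1, 3]
- Each fully covered node updates its own sum and records +10 in lazy[] for its children
- Subtrees below a covered node are touched only when a later query or update reaches them
Lazy array after update:
lazy[children of each fully covered node] += 10
Deeper nodes not updated yet!
(In this five-element tree, [1, 3] splits into the leaves [1], [2] and [3], so nothing stays pending. Updating [0, 2] instead would cover the internal node [0-2] and leave lazy = 10 on [0-1] and [2].)
On next query:
- Push lazy value down to children
- Update happens on-demand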
Benefits:
- Range update: O(log n) instead of O(n log n)
- Multiple updates before query: Very efficient
- Trade-off: Slightly more complex code
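Lazy propagation is not tied to sums. As an illustration, here is a sketch of a range-add, range-minimum tree. It is a variant, not the implementation above: the class name `LazyMinTree` is an assumption, and it uses a slightly different convention in which a node's stored value is always up to date and `lazy[node]` holds only what its children still owe. Adding delta to every element of a range shifts its minimum by delta, so no multiplication by the range length is needed.

```python
class LazyMinTree:
    """Range add + range min with lazy propagation (illustrative variant)."""

    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n)
        self.lazy = [0] * (4 * self.n)
        if self.n > 0:
            self._build(arr, 1, 0, self.n - 1)

    def _build(self, arr, node, start, end):
        if start == end:
            self.tree[node] = arr[start]
            return
        mid = (start + end) // 2
        self._build(arr, 2 * node, start, mid)
        self._build(arr, 2 * node + 1, mid + 1, end)
        self.tree[node] = min(self.tree[2 * node], self.tree[2 * node + 1])

    def _apply(self, node, delta):
        self.tree[node] += delta   # the node's own min is updated immediately
        self.lazy[node] += delta   # its children still owe this amount

    def _push(self, node):
        # Only called on internal nodes (a leaf is never partially covered)
        if self.lazy[node]:
            self._apply(2 * node, self.lazy[node])
            self._apply(2 * node + 1, self.lazy[node])
            self.lazy[node] = 0

    def add(self, left, right, delta, node=1, start=0, end=None):
        if end is None:
            end = self.n - 1
        if right < start or end < left:
            return
        if left <= start and end <= right:
            self._apply(node, delta)
            return
        self._push(node)
        mid = (start + end) // 2
        self.add(left, right, delta, 2 * node, start, mid)
        self.add(left, right, delta, 2 * node + 1, mid + 1, end)
        self.tree[node] = min(self.tree[2 * node], self.tree[2 * node + 1])

    def query(self, left, right, node=1, start=0, end=None):
        if end is None:
            end = self.n - 1
        if right < start or end < left:
            return float('inf')
        if left <= start and end <= right:
            return self.tree[node]
        self._push(node)
        mid = (start + end) // 2
        return min(self.query(left, right, 2 * node, start, mid),
                   self.query(left, right, 2 * node + 1, mid + 1, end))


t = LazyMinTree([5, 2, 7, 1, 9])
print(t.query(0, 4))   # 1
t.add(3, 3, 10)        # array is now [5, 2, 7, 11, 9]
print(t.query(0, 4))   # 2
t.add(0, 2, -4)        # array is now [1, -2, 3, 11, 9]
print(t.query(0, 4), t.query(2, 4))  # -2 3
```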
5. Space/time complexities
Understanding performance characteristics:
Space complexity:
# Recursive representation (array-based)
# For n elements:
# - Tree height: h = ceil(log₂ n)
# - Nodes at level i: at most 2^i
# - Total nodes: at most 2^(h+1) - 1 < 4n
# Example: n = 5
# Height = 3
# Nodes = at most 2^4 - 1 = 15
# Array size = 4 * 5 = 20 (safe upper bound)
import math

def calculate_tree_size(n):
    """Slots needed for n >= 1 elements, counting the unused index 0"""
    height = math.ceil(math.log2(n))
    return 2 ** (height + 1)
# Iterative representation
# Size = 2 * n (more space-efficient)
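The 4n bound can be checked empirically by walking the same recursion the build uses and tracking the largest index it would write. This is a small sketch; the function name `max_node_index` is an assumption for this example.

```python
def max_node_index(node, start, end):
    """Largest tree-array index the recursive build touches for [start, end]."""
    if start == end:
        return node
    mid = (start + end) // 2
    return max(max_node_index(2 * node, start, mid),
               max_node_index(2 * node + 1, mid + 1, end))


for n in (5, 6):
    print(n, max_node_index(1, 0, n - 1), 4 * n)
# 5 9 20
# 6 13 24

# The bound holds for every size tried here
assert all(max_node_index(1, 0, n - 1) < 4 * n for n in range(1, 500))
```

For many sizes the array is far from full, so 4n is a convenient allocation rule rather than a tight count.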
Time complexities:
| Operation | Without Lazy | With Lazy |
|---|---|---|
| Build | O(n) | O(n) |
| Point update | O(log n) | O(log n) |
| Point query | O(log n) | O(log n) |
| Range query | O(log n) | O(log n) |
| Range update | O(n log n) | O(log n) |
Comparison with alternatives:
| Structure | Build | Query | Update | Space |
|---|---|---|---|---|
| Array | O(1) | O(n) | O(1) | O(n) |
| Prefix Sum | O(n) | O(1) | O(n) | O(n) |
| Sqrt Decomposition | O(n) | O(√n) | O(√n) | O(n) |
| Segment Tree | O(n) | O(log n) | O(log n) | O(4n) |
| Fenwick Tree | O(n log n) (O(n) possible) | O(log n) | O(log n) | O(n) |
Memory usage example:
# For 1 million elements
n = 1_000_000
bytes_per_value = 8  # assumes packed 8-byte values, e.g. 64-bit integers
# Segment tree
seg_tree_bytes = 4 * n * bytes_per_value  # 32 MB (with lazy: 64 MB)
# Fenwick tree
fenwick_bytes = n * bytes_per_value  # 8 MB
# Trade-off: Segment tree uses more memory but is more versatile
When to use segment trees:
✓ Need range queries (sum, min, max, GCD, etc.)
✓ Need range updates
✓ Operation is associative
✓ Memory not severely constrained
✓ Need flexibility (any associative operation)
When to use alternatives:
- Fenwick Tree: Only sum/XOR, want less memory
- Sqrt Decomposition: Simpler code, O(√n) acceptable
- Sparse Table: Static array, only queries (no updates)
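For comparison, a Fenwick tree for point updates and range sums fits in a few lines and one array of n + 1 slots. The sketch below is illustrative; the class and method names are assumptions for this example, and indices are 0-based at the interface.

```python
class FenwickTree:
    """Point add / range sum in O(log n) using n + 1 slots."""

    def __init__(self, arr):
        self.n = len(arr)
        self.bit = [0] * (self.n + 1)          # 1-indexed internally
        for i, value in enumerate(arr, start=1):
            self.bit[i] += value
            parent = i + (i & -i)              # next node that covers i
            if parent <= self.n:
                self.bit[parent] += self.bit[i]

    def add(self, idx, delta):
        """arr[idx] += delta"""
        i = idx + 1
        while i <= self.n:
            self.bit[i] += delta
            i += i & -i

    def prefix_sum(self, idx):
        """arr[0] + ... + arr[idx]"""
        i = idx + 1
        total = 0
        while i > 0:
            total += self.bit[i]
            i -= i & -i
        return total

    def range_sum(self, left, right):
        below = self.prefix_sum(left - 1) if left > 0 else 0
        return self.prefix_sum(right) - below


fw = FenwickTree([1, 3, 5, 7, 9, 11])
print(fw.range_sum(1, 4))   # 24
fw.add(2, 3)                # index 2 goes from 5 to 8
print(fw.range_sum(0, 5))   # 39
```

The range sum is computed as a difference of two prefix sums, which is why this structure suits invertible operations such as sum or XOR but not min or max.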
6. Use cases in CP and databases
Real-world applications:
Competitive Programming:
1. Range sum with updates:
# Problem: Dynamic range sum queries
def solve_range_sum():
    arr = [1, 3, 5, 7, 9, 11]
    seg_tree = SegmentTreeLazy(arr)
    # Query sum in range [1, 4]
    print(seg_tree.query(1, 4))  # 3+5+7+9 = 24
    # Update: add 10 to range [2, 4]
    seg_tree.update(2, 4, 10)
    # Query again
    print(seg_tree.query(1, 4))  # 3+15+17+19 = 54
2. Range minimum query with updates:
# Problem: Find minimum in range, with updates
class RMQSegmentTree:
    # Similar to sum, but use min instead
    def build(self, node, start, end):
        if start == end:
            self.tree[node] = self.arr[start]
            return
        mid = (start + end) // 2
        self.build(2*node, start, mid)
        self.build(2*node+1, mid+1, end)
        self.tree[node] = min(self.tree[2*node], self.tree[2*node+1])
3. Count elements in range:
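# Problem: Count elements in range [L, R] with value in [a, b]
# Solution: Merge sort tree (segment tree of sorted arrays)
import bisect
import heapq

class MergeSortTree:
    def __init__(self, arr):
        self.n = len(arr)
        self.arr = arr
        self.tree = [[] for _ in range(4 * self.n)]
        if self.n > 0:
            self.build(1, 0, self.n - 1)

    def merge(self, left, right):
        """Merge two sorted lists in linear time (helper filled in for this sketch)"""
        return list(heapq.merge(left, right))

    def count_between(self, sorted_vals, a, b):
        """Count values v with a <= v <= b (helper filled in for this sketch)"""
        return bisect.bisect_right(sorted_vals, b) - bisect.bisect_left(sorted_vals, a)

    def build(self, node, start, end):
        if start == end:
            self.tree[node] = [self.arr[start]]
            return
        mid = (start + end) // 2
        self.build(2*node, start, mid)
        self.build(2*node+1, mid+1, end)
        # Merge sorted arrays
        self.tree[node] = self.merge(self.tree[2*node],
                                     self.tree[2*node+1])

    def count_in_range(self, node, start, end, left, right, a, b):
        """Count elements in [left,right] with value in [a,b]"""
        if right < start or left > end:
            return 0
        if left <= start and end <= right:
            # Binary search in sorted array
            return self.count_between(self.tree[node], a, b)
        mid = (start + end) // 2
        return (self.count_in_range(2*node, start, mid, left, right, a, b) +
                self.count_in_range(2*node+1, mid+1, end, left, right, a, b))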
4. Maximum subarray sum in range:
class MaxSubarraySegTree:
    """
    Each node stores: (max_sum, prefix_sum, suffix_sum, total_sum)
    """
    def combine(self, left, right):
        max_sum = max(left[0], right[0],
                      left[2] + right[1])  # Crossing middle
        prefix_sum = max(left[1], left[3] + right[1])
        suffix_sum = max(right[2], right[3] + left[2])
        total_sum = left[3] + right[3]
        return (max_sum, prefix_sum, suffix_sum, total_sum)
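To show how the combine step is used end to end, here is a standalone sketch that builds a tree of such tuples and answers "maximum subarray sum inside [left, right]". The free functions, the leaf encoding `(v, v, v, v)`, and the use of `None` for "no overlap" are assumptions made for this example; subarrays are taken to be non-empty.

```python
def combine(left, right):
    """Merge two adjacent (max_sum, prefix_sum, suffix_sum, total_sum) tuples."""
    return (max(left[0], right[0], left[2] + right[1]),
            max(left[1], left[3] + right[1]),
            max(right[2], right[3] + left[2]),
            left[3] + right[3])


def build(arr, tree, node, start, end):
    if start == end:
        v = arr[start]
        tree[node] = (v, v, v, v)  # a single element is its own best subarray
        return
    mid = (start + end) // 2
    build(arr, tree, 2 * node, start, mid)
    build(arr, tree, 2 * node + 1, mid + 1, end)
    tree[node] = combine(tree[2 * node], tree[2 * node + 1])


def query(tree, node, start, end, left, right):
    """Return the tuple for [left, right], or None if there is no overlap."""
    if right < start or left > end:
        return None
    if left <= start and end <= right:
        return tree[node]
    mid = (start + end) // 2
    lhs = query(tree, 2 * node, start, mid, left, right)
    rhs = query(tree, 2 * node + 1, mid + 1, end, left, right)
    if lhs is None:
        return rhs
    if rhs is None:
        return lhs
    return combine(lhs, rhs)


arr = [-2, 1, -3, 4, -1, 2, 1, -5, 4]
tree = [None] * (4 * len(arr))
build(arr, tree, 1, 0, len(arr) - 1)
print(query(tree, 1, 0, 8, 0, 8)[0])  # 6  (4 + -1 + 2 + 1)
print(query(tree, 1, 0, 8, 0, 2)[0])  # 1
print(query(tree, 1, 0, 8, 4, 8)[0])  # 4
```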
Database applications:
1. Range aggregations:
-- SQL query: SELECT SUM(value) FROM table WHERE id BETWEEN 100 AND 200
-- A segment tree kept over id could answer this range sum in O(log n) instead of scanning rows
2. Time-series databases:
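# Store time-series data with efficient range queries
class TimeSeriesDB:
    def __init__(self, num_slots):
        # num_slots is an assumption of this sketch: one slot per integer
        # timestamp in [0, num_slots), since an empty tree cannot be queried
        self.seg_tree = SegmentTreeLazy([0] * num_slots)

    def add_metric(self, timestamp, value):
        # Add value at timestamp
        self.seg_tree.update(timestamp, timestamp, value)

    def get_avg(self, start_time, end_time):
        # Get average in time range (per slot, empty slots count as 0)
        total = self.seg_tree.query(start_time, end_time)
        count = end_time - start_time + 1
        return total / count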
3. Spatial databases:
# 2D segment tree for spatial range queries
class SpatialIndex:
    """
    Query: Find all points in rectangle [x1,y1] to [x2,y2]
    Solution: 2D segment tree
    """
    def range_query_2d(self, x1, y1, x2, y2):
        # Query x-dimension segment tree
        # Each node has y-dimension segment tree
        pass
Other applications:
- Graphics: Rectangle union area, line sweep algorithms
- Computational geometry: Closest pair, intersection queries
- Game development: Collision detection, visibility queries
- Network monitoring: Traffic analysis, anomaly detection
- Financial systems: OHLC (Open-High-Low-Close) queries
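As a small illustration of the OHLC case, the merge rule for two adjacent time buckets is associative, which is what lets it serve as a segment tree node operation. The tuple layout `(open, high, low, close)` and the function name `combine_ohlc` are assumptions for this sketch; `functools.reduce` stands in for the tree to keep the example short.

```python
from functools import reduce


def combine_ohlc(left, right):
    """Merge two adjacent bars given as (open, high, low, close) tuples."""
    return (left[0],                   # open of the earlier bar
            max(left[1], right[1]),    # highest high
            min(left[2], right[2]),    # lowest low
            right[3])                  # close of the later bar


bars = [(10, 12, 9, 11), (11, 15, 10, 14), (14, 14, 8, 9)]
print(reduce(combine_ohlc, bars))  # (10, 15, 8, 9)

# Grouping does not matter, but order does (the operation is not commutative)
a, b, c = bars
assert combine_ohlc(combine_ohlc(a, b), c) == combine_ohlc(a, combine_ohlc(b, c))
```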
Example: LeetCode problems using segment trees:
- Range Sum Query - Mutable (LeetCode 307)
- Count of Range Sum (LeetCode 327)
- Range Module (LeetCode 715)
- Falling Squares (LeetCode 699)
- My Calendar III (LeetCode 732)