Disjoint Set / Union-Find
Maintain dynamic connectivity across disjoint sets.
1. Parent array and forest representation
Understanding the fundamental structure:
Concept:
- Disjoint Set: Collection of non-overlapping sets
- Representative: Each set has a representative (root) element
- Parent Array: Each element points to its parent
- Forest: Multiple trees, each representing a set
Basic structure:
class UnionFind:
    def __init__(self, n):
        # Initially, each element is its own parent (separate set)
        self.parent = list(range(n))
        self.count = n  # Number of disjoint sets
Visual representation:
Initial state (n=5):
parent = [0, 1, 2, 3, 4]
Each element is its own set:
0 1 2 3 4
After union(0, 1) and union(2, 3) (the naive union below attaches the first root under the second):
parent = [1, 1, 3, 3, 4]
Forest representation:
1   3   4
|   |
0   2
Sets: {0,1}, {2,3}, {4}
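To read a parent array back as a collection of sets, one can follow each element up to its root and bucket elements by that root. The sketch below is illustrative only; `find_root` and `groups` are hypothetical helper names, not part of the class in this section. It also shows that different forests can encode the same partition, since only the roots matter.

```python
from collections import defaultdict

def find_root(parent, x):
    """Follow parent pointers from x until reaching a self-parented root."""
    while parent[x] != x:
        x = parent[x]
    return x

def groups(parent):
    """Return the disjoint sets encoded by a parent array, as sorted lists."""
    buckets = defaultdict(list)
    for x in range(len(parent)):
        buckets[find_root(parent, x)].append(x)
    return sorted(buckets.values())

print(groups([0, 1, 2, 3, 4]))  # [[0], [1], [2], [3], [4]]
print(groups([0, 0, 2, 2, 4]))  # [[0, 1], [2, 3], [4]]
print(groups([1, 1, 3, 3, 4]))  # [[0, 1], [2, 3], [4]] (same sets, different roots)
```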
Basic operations (naive):
def find(self, x):
    """Find root of element x"""
    while self.parent[x] != x:
        x = self.parent[x]
    return x

def union(self, x, y):
    """Merge sets containing x and y; return True if a merge happened"""
    root_x = self.find(x)
    root_y = self.find(y)
    if root_x == root_y:
        return False
    self.parent[root_x] = root_y
    self.count -= 1
    return True

def connected(self, x, y):
    """Check if x and y are in same set"""
    return self.find(x) == self.find(y)
Example usage:
uf = UnionFind(5)
# Union operations
uf.union(0, 1) # Merge {0} and {1}
uf.union(2, 3) # Merge {2} and {3}
uf.union(0, 2) # Merge {0,1} and {2,3}
# Check connectivity
print(uf.connected(1, 3)) # True (both in {0,1,2,3})
print(uf.connected(1, 4)) # False (4 is separate)
print(uf.count) # 2 sets: {0,1,2,3}, {4}
2. Find with path compression
Optimizing the find operation:
Problem with naive find:
Worst case: Linear chain
4 → 3 → 2 → 1 → 0 (each arrow points to the parent; 0 is the root)
find(4) requires 4 steps
Repeated finds on the same element: O(n) each time
Path compression:
- Idea: Make every node point directly to root during find
- Effect: Flattens tree structure
- Benefit: Later finds on the compressed nodes take a single step, until their root is attached under another root
Implementation:
def find(self, x):
    """Find with path compression"""
    if self.parent[x] != x:
        # Recursively find root and compress path
        self.parent[x] = self.find(self.parent[x])
    return self.parent[x]
Visual example:
Before find(4):
0
|
1
|
2
|
3
|
4
After find(4) with path compression:
      0
   / | | \
  1  2 3  4
All nodes now point directly to root!
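The picture above can be reproduced on a plain list. This is a minimal sketch that assumes a standalone `find(parent, x)` function operating on a list rather than the class method; the chain is encoded with 0 as the root and each node pointing at the node above it.

```python
def find(parent, x):
    """Recursive find with path compression on a plain parent list."""
    if parent[x] != x:
        parent[x] = find(parent, parent[x])
    return parent[x]

# Chain from the picture: 4 -> 3 -> 2 -> 1 -> 0, with 0 as the root.
parent = [0, 0, 1, 2, 3]
root = find(parent, 4)
print(root)    # 0
print(parent)  # [0, 0, 0, 0, 0]
```

On very long chains the recursive form can exceed Python's default recursion limit, which is one reason to prefer an iterative variant.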
Iterative path compression:
def find_iterative(self, x):
    """Iterative find with path compression"""
    root = x
    # Find root
    while self.parent[root] != root:
        root = self.parent[root]
    # Compress path
    while x != root:
        next_parent = self.parent[x]
        self.parent[x] = root
        x = next_parent
    return root
Two-pass path compression (the same algorithm as find_iterative, with the two passes labeled):
def find_two_pass(self, x):
    """Two-pass: find root, then compress"""
    # First pass: find root
    root = x
    while self.parent[root] != root:
        root = self.parent[root]
    # Second pass: point all nodes to root
    while x != root:
        parent = self.parent[x]
        self.parent[x] = root
        x = parent
    return root
Path halving (one-pass optimization):
def find_path_halving(self, x):
    """Point every other node on the path to its grandparent"""
    while self.parent[x] != x:
        self.parent[x] = self.parent[self.parent[x]]
        x = self.parent[x]
    return x
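Path halving does not flatten a path in one call; each call roughly halves its length. The sketch below traces that on a chain of eight nodes, assuming a standalone list-based version of the function and a hypothetical `depth` helper that counts hops to the root.

```python
def find_path_halving(parent, x):
    """Path halving on a plain parent list; returns the root of x."""
    while parent[x] != x:
        parent[x] = parent[parent[x]]  # skip over the parent
        x = parent[x]                  # continue from the old grandparent
    return x

def depth(parent, x):
    """Number of parent-pointer hops from x to its root."""
    hops = 0
    while parent[x] != x:
        x = parent[x]
        hops += 1
    return hops

# Chain 7 -> 6 -> ... -> 1 -> 0, with 0 as the root.
parent = [0, 0, 1, 2, 3, 4, 5, 6]
print(depth(parent, 7))       # 7
find_path_halving(parent, 7)
print(parent)                 # [0, 0, 1, 1, 3, 3, 5, 5]
print(depth(parent, 7))       # 4
find_path_halving(parent, 7)
print(depth(parent, 7))       # 2
```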
3. Union by rank/size
Optimizing the union operation:
Problem with naive union:
Always attach first tree to second:
union(0, 1): 0 → 1
union(0, 2): 0 → 1 → 2
union(0, 3): 0 → 1 → 2 → 3
Creates linear chain (worst case)
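The chain can be observed directly by counting hops. This sketch assumes list-based helpers (`find` returning both the root and the hop count, and `naive_union`), which are illustrative names rather than the methods of the classes in this document.

```python
def find(parent, x):
    """Return (root, hops) for x without modifying the forest."""
    hops = 0
    while parent[x] != x:
        x = parent[x]
        hops += 1
    return x, hops

def naive_union(parent, x, y):
    """Always attach the root of x under the root of y."""
    root_x, _ = find(parent, x)
    root_y, _ = find(parent, y)
    if root_x != root_y:
        parent[root_x] = root_y

n = 1000
parent = list(range(n))
for i in range(1, n):
    naive_union(parent, 0, i)

root, hops = find(parent, 0)
print(root, hops)  # 999 999
```

Every union here has to walk the whole chain built so far, so the total work for the loop grows quadratically in n.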
Union by rank:
- Rank: Upper bound on tree height
- Rule: Attach smaller rank tree under larger rank tree
- Benefit: Keeps trees balanced
Implementation:
class UnionFindRank:
    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n  # Initial rank is 0
        self.count = n

    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x, y):
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return False
        # Attach smaller rank tree under larger rank tree
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        elif self.rank[root_x] > self.rank[root_y]:
            self.parent[root_y] = root_x
        else:
            # Equal rank: attach either way, increase rank
            self.parent[root_y] = root_x
            self.rank[root_x] += 1
        self.count -= 1
        return True
Visual example:
Union by rank:
Tree A (rank 1):    Tree B (rank 2):
    0                   2
    |                  / \
    1                 3   4
                      |
                      5
union(0, 2): Attach A under B (smaller rank under larger)
Result (rank 2, unchanged because the ranks differed):
      2
    / | \
   3  4  0
   |     |
   5     1
Union by size:
- Size: Number of elements in tree
- Rule: Attach smaller tree under larger tree
- Benefit: Also keeps trees balanced
Implementation:
class UnionFindSize:
    def __init__(self, n):
        self.parent = list(range(n))
        self.size = [1] * n  # Initial size is 1
        self.count = n

    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x, y):
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return False
        # Attach smaller tree under larger tree
        if self.size[root_x] < self.size[root_y]:
            self.parent[root_x] = root_y
            self.size[root_y] += self.size[root_x]
        else:
            self.parent[root_y] = root_x
            self.size[root_x] += self.size[root_y]
        self.count -= 1
        return True

    def get_size(self, x):
        """Get size of set containing x"""
        return self.size[self.find(x)]
Comparison:
# Without optimization: O(n) per operation
# With path compression only: O(log n) amortized
# With union by rank only: O(log n) per operation
# With both: O(α(n)) amortized (nearly O(1))
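These bounds can be made tangible by counting parent-pointer hops for each combination of optimizations. The following is a rough sketch, not a benchmark: `CountingUF` and `total_hops` are hypothetical names, union by size stands in for the balancing rule, and the workload (merge everything into the set of element 0, then find every element) is chosen because it is the worst case for the naive version.

```python
class CountingUF:
    """Union-find that counts parent-pointer hops; flags toggle each optimization."""

    def __init__(self, n, compress, by_size):
        self.parent = list(range(n))
        self.size = [1] * n
        self.compress = compress
        self.by_size = by_size
        self.hops = 0

    def find(self, x):
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
            self.hops += 1
        if self.compress:
            # Second pass: point every node on the path at the root
            while x != root:
                next_x = self.parent[x]
                self.parent[x] = root
                x = next_x
        return root

    def union(self, x, y):
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return False
        if self.by_size and self.size[root_x] > self.size[root_y]:
            root_x, root_y = root_y, root_x  # make root_x the smaller tree
        self.parent[root_x] = root_y
        self.size[root_y] += self.size[root_x]
        return True

def total_hops(n, compress, by_size):
    """Run the fixed workload and return the number of hops it needed."""
    uf = CountingUF(n, compress, by_size)
    for i in range(1, n):
        uf.union(0, i)
    for x in range(n):
        uf.find(x)
    return uf.hops

n = 512
results = {
    "naive": total_hops(n, False, False),
    "compression only": total_hops(n, True, False),
    "size only": total_hops(n, False, True),
    "both": total_hops(n, True, True),
}
for name, hops in results.items():
    print(f"{name:>16}: {hops}")
```

On this workload the naive count grows quadratically with n, while each optimized variant stays linear. A single workload cannot separate the optimized variants from one another; that takes adversarial inputs or the amortized analysis.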
4. Amortized inverse Ackermann complexity
Understanding the remarkable efficiency:
Ackermann function:
# Ackermann function (grows extremely fast)
def A(m, n):
    if m == 0:
        return n + 1
    if n == 0:
        return A(m - 1, 1)
    return A(m - 1, A(m, n - 1))
# A(0,n) = n+1
# A(1,n) = n+2
# A(2,n) = 2n+3
# A(3,n) = 2^(n+3) - 3
# A(4,n) = 2^2^2^... (n+3 times) - 3
Inverse Ackermann α(n):
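# α(n) = minimum m such that A(m, m) ≥ n
#
# α(n) grows EXTREMELY slowly:
# A(0,0) = 1   →  α(n) = 0 for n ≤ 1
# A(1,1) = 3   →  α(n) = 1 for 2 ≤ n ≤ 3
# A(2,2) = 7   →  α(n) = 2 for 4 ≤ n ≤ 7
# A(3,3) = 61  →  α(n) = 3 for 8 ≤ n ≤ 61
# A(4,4) = 2^2^2^... (7 twos) - 3  →  α(n) = 4 for 62 ≤ n ≤ A(4,4)
#
# For all practical n (even n = 10^80), α(n) ≤ 4
# (other texts define α slightly differently, which shifts these values by a small constant)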
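Evaluating the definition α(n) = min m with A(m, m) ≥ n is straightforward once the closed forms for small m are used in place of the recursion. The sketch below assumes exactly that definition; `ackermann_diagonal` and `inverse_ackermann` are hypothetical helper names. Variants of the definition found elsewhere shift the results by a small constant.

```python
def ackermann_diagonal(m):
    """A(m, m) for m <= 3, using the closed forms for A(0..3, n)."""
    closed_forms = {
        0: lambda n: n + 1,
        1: lambda n: n + 2,
        2: lambda n: 2 * n + 3,
        3: lambda n: 2 ** (n + 3) - 3,
    }
    return closed_forms[m](m)

def inverse_ackermann(n):
    """Smallest m with A(m, m) >= n."""
    for m in range(4):
        if ackermann_diagonal(m) >= n:
            return m
    # A(4, 4) is a tower of seven 2s minus 3, larger than any integer
    # that fits in memory, so every remaining n maps to 4.
    return 4

for n in (1, 3, 7, 61, 62, 10**9, 10**80):
    print(n, inverse_ackermann(n))
```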
Complexity analysis:
# Union-Find with both optimizations:
# - Path compression
# - Union by rank/size
#
# Time complexity:
# - m operations on n elements: O(m · α(n))
# - Amortized per operation: O(α(n))
# - Practically: O(1) per operation
#
# Space complexity: O(n)
Practical implications:
# For n = 10^9 elements:
# α(10^9) = 4
#
# This means operations are essentially O(1)!
#
# Example: 10^9 operations on 10^9 elements
# Theoretical: O(10^9 · 4) ≈ 4 billion operations
# vs naive O(n): O(10^9 · 10^9) = 10^18 operations
#
# Speedup: ~250 million times faster!
5. Applications: Kruskal’s MST, connectivity
Real-world uses of Union-Find:
Kruskal’s Minimum Spanning Tree:
def kruskal_mst(n, edges):
    """
    Find minimum spanning tree using Kruskal's algorithm
    n: number of vertices
    edges: list of (weight, u, v)
    """
    # Sort edges by weight (in place; tuples compare by weight first)
    edges.sort()
    uf = UnionFind(n)  # union() must return True when it merges two sets
    mst = []
    total_weight = 0
    for weight, u, v in edges:
        # If u and v not connected, add edge
        if uf.union(u, v):
            mst.append((u, v, weight))
            total_weight += weight
            # MST complete when n-1 edges added
            if len(mst) == n - 1:
                break
    return mst, total_weight

# Example
edges = [
    (1, 0, 1),  # (weight, u, v)
    (2, 0, 2),
    (3, 1, 2),
    (4, 1, 3),
    (5, 2, 3)
]
mst, weight = kruskal_mst(4, edges)
# MST: [(0,1,1), (0,2,2), (1,3,4)]
# Total weight: 7
Network connectivity:
class NetworkConnectivity:
    """Check if network is fully connected"""
    def __init__(self, n):
        self.uf = UnionFind(n)

    def add_connection(self, u, v):
        """Add connection between nodes u and v"""
        self.uf.union(u, v)

    def is_connected(self, u, v):
        """Check if u and v can communicate"""
        return self.uf.connected(u, v)

    def is_fully_connected(self):
        """Check if all nodes are connected"""
        return self.uf.count == 1

    def count_components(self):
        """Count number of separate networks"""
        return self.uf.count
Number of islands (2D grid):
def num_islands(grid):
    """Count islands in 2D grid"""
    if not grid:
        return 0
    rows, cols = len(grid), len(grid[0])
    uf = UnionFind(rows * cols)

    def get_id(r, c):
        return r * cols + c

    # Union adjacent land cells
    for r in range(rows):
        for c in range(cols):
            if grid[r][c] == '1':
                # Check right neighbor
                if c + 1 < cols and grid[r][c + 1] == '1':
                    uf.union(get_id(r, c), get_id(r, c + 1))
                # Check down neighbor
                if r + 1 < rows and grid[r + 1][c] == '1':
                    uf.union(get_id(r, c), get_id(r + 1, c))
    # Count unique roots for land cells
    islands = set()
    for r in range(rows):
        for c in range(cols):
            if grid[r][c] == '1':
                islands.add(uf.find(get_id(r, c)))
    return len(islands)

# Example
grid = [
    ['1', '1', '0', '0', '0'],
    ['1', '1', '0', '0', '0'],
    ['0', '0', '1', '0', '0'],
    ['0', '0', '0', '1', '1']
]
print(num_islands(grid))  # 3 islands
Detect cycle in undirected graph:
def has_cycle(n, edges):
    """Check if undirected graph has cycle"""
    uf = UnionFind(n)
    for u, v in edges:
        # If u and v already connected, adding edge creates cycle
        if uf.connected(u, v):
            return True
        uf.union(u, v)
    return False

# Example
edges = [(0, 1), (1, 2), (2, 0)]  # Triangle
print(has_cycle(3, edges))  # True
Friend circles:
def find_circle_num(is_connected):
    """
    Count number of friend circles
    is_connected[i][j] = 1 if person i and j are friends
    """
    n = len(is_connected)
    uf = UnionFind(n)
    for i in range(n):
        for j in range(i + 1, n):
            if is_connected[i][j] == 1:
                uf.union(i, j)
    return uf.count

# Example
is_connected = [
    [1, 1, 0],
    [1, 1, 0],
    [0, 0, 1]
]
print(find_circle_num(is_connected))  # 2 circles
Redundant connection:
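def find_redundant_connection(edges):
    """Find edge that creates cycle in tree"""
    # A tree on n nodes plus one extra edge has exactly n edges
    n = len(edges)
    uf = UnionFind(n + 1)  # Nodes are labeled 1..n; index 0 is unused
    for u, v in edges:
        # Relies on union() returning False when no merge happens
        if not uf.union(u, v):
            # Union failed: u and v already connected
            return [u, v]
    return []

# Example
edges = [[1,2], [1,3], [2,3]]
print(find_redundant_connection(edges))  # [2,3]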
6. Implementation patterns
Common variations and techniques:
Complete optimized implementation:
class UnionFind:
    """Union-Find with all optimizations"""
    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n
        self.count = n

    def find(self, x):
        """Find with path compression"""
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x, y):
        """Union by rank"""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return False
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        elif self.rank[root_x] > self.rank[root_y]:
            self.parent[root_y] = root_x
        else:
            self.parent[root_y] = root_x
            self.rank[root_x] += 1
        self.count -= 1
        return True

    def connected(self, x, y):
        """Check connectivity"""
        return self.find(x) == self.find(y)

    def get_count(self):
        """Get number of disjoint sets"""
        return self.count
With component size tracking:
class UnionFindWithSize:
    """Track size of each component"""
    def __init__(self, n):
        self.parent = list(range(n))
        self.size = [1] * n
        self.count = n
        self.max_size = 1  # Track largest component

    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x, y):
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return False
        # Always attach smaller to larger
        if self.size[root_x] < self.size[root_y]:
            self.parent[root_x] = root_y
            self.size[root_y] += self.size[root_x]
            self.max_size = max(self.max_size, self.size[root_y])
        else:
            self.parent[root_y] = root_x
            self.size[root_x] += self.size[root_y]
            self.max_size = max(self.max_size, self.size[root_x])
        self.count -= 1
        return True

    def get_size(self, x):
        """Get size of component containing x"""
        return self.size[self.find(x)]

    def get_max_size(self):
        """Get size of largest component"""
        return self.max_size
With custom merge logic:
class UnionFindCustom:
    """Union-Find with custom merge function"""
    def __init__(self, n, merge_fn=None):
        self.parent = list(range(n))
        self.rank = [0] * n
        self.data = [None] * n  # Store custom data
        self.merge_fn = merge_fn or (lambda a, b: a)
        self.count = n

    def set_data(self, x, data):
        """Set data for element x (call before x is merged into another set)"""
        self.data[x] = data

    def find(self, x):
        """Find with path compression (used by get_data and union)"""
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def get_data(self, x):
        """Get data for component containing x"""
        return self.data[self.find(x)]

    def union(self, x, y):
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return False
        # Merge data using custom function
        merged_data = self.merge_fn(self.data[root_x], self.data[root_y])
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
            self.data[root_y] = merged_data
        elif self.rank[root_x] > self.rank[root_y]:
            self.parent[root_y] = root_x
            self.data[root_x] = merged_data
        else:
            self.parent[root_y] = root_x
            self.rank[root_x] += 1
            self.data[root_x] = merged_data
        self.count -= 1
        return True

# Example: Track sum of each component
uf = UnionFindCustom(5, merge_fn=lambda a, b: a + b)
for i in range(5):
    uf.set_data(i, i)  # Set initial values
uf.union(0, 1)  # Merge: data becomes 0+1=1
uf.union(2, 3)  # Merge: data becomes 2+3=5
print(uf.get_data(0))  # 1
print(uf.get_data(2))  # 5
Persistent Union-Find (snapshot-based):
class PersistentUnionFind:
    """Union-Find that keeps a full snapshot per version for time-travel queries"""
    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n
        # Version 0 is the initial state; version k is the state after the k-th union
        self.history = [{'parent': self.parent[:], 'rank': self.rank[:]}]

    def find(self, x, version=-1):
        """Find at specific version (-1 means the latest state)"""
        if version == -1:
            parent = self.parent
        else:
            parent = self.history[version]['parent']
        # No path compression, so saved snapshots are never modified
        while parent[x] != x:
            x = parent[x]
        return x

    def union(self, x, y):
        """Apply union, save the new state, and return its version number"""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x != root_y:
            if self.rank[root_x] < self.rank[root_y]:
                self.parent[root_x] = root_y
            elif self.rank[root_x] > self.rank[root_y]:
                self.parent[root_y] = root_x
            else:
                self.parent[root_y] = root_x
                self.rank[root_x] += 1
        # Snapshot after the update: O(n) time and space per version
        self.history.append({
            'parent': self.parent[:],
            'rank': self.rank[:]
        })
        return len(self.history) - 1  # Return version number
Complexity summary:
| Operation | Naive | Path Compression | Union by Rank | Both |
|---|---|---|---|---|
| Find | O(n) | O(log n) amortized | O(log n) | O(α(n)) amortized |
| Union | O(n) | O(log n) amortized | O(log n) | O(α(n)) amortized |
| Space | O(n) | O(n) | O(n) | O(n) |
Where α(n) is the inverse Ackermann function, effectively O(1) for all practical values of n.