Algorithms & DS · Advanced

Union-Find: The Disjoint Set Data Structure

Learn the Union-Find structure with path compression and union by rank, which together give near-constant amortized time per operation. Solve connected components, cycle detection, and Kruskal's MST.

13 min read
March 5, 2026
Union-Find · Disjoint Set · Connected Components · MST · Python

When to Use Union-Find

Union-Find tracks connected components that only grow over time, through merge operations. It answers "are X and Y connected?" in near-constant amortized time. Use it for dynamic connectivity, cycle detection in undirected graphs, Kruskal's minimum spanning tree, and merging problems such as accounts merge and number of islands.
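
To make the cycle-detection use case concrete, here is a minimal sketch, assuming vertices are numbered 0..n-1 and edges arrive as (u, v) pairs. `has_cycle` is a hypothetical helper, not part of the class developed below; to stay short it uses a stripped-down union-find with path halving (a one-pass variant of path compression) and no rank. The key observation: an edge whose endpoints already share a root must close a cycle.

```python
def has_cycle(n: int, edges: list[tuple[int, int]]) -> bool:
    """Return True if the undirected graph on vertices 0..n-1 contains a cycle."""
    parent = list(range(n))  # every vertex starts in its own set

    def find(x: int) -> int:
        # Walk up to the root; path halving points each visited node at its grandparent.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    for u, v in edges:
        ru, rv = find(u), find(v)
        if ru == rv:
            # u and v were already connected, so edge (u, v) closes a cycle.
            return True
        parent[ru] = rv  # merge the two components
    return False


print(has_cycle(4, [(0, 1), (1, 2), (2, 3)]))  # False: a simple path
print(has_cycle(4, [(0, 1), (1, 2), (2, 0)]))  # True: triangle 0-1-2
```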

Implementation with Optimizations

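Two optimizations make Union-Find nearly O(1) per operation: path compression (during find, point every node on the search path directly at the root) and union by rank (attach the root of lower rank under the root of higher rank, where rank is an upper bound on tree height). Together they achieve O(α(n)) amortized time per operation, where α is the inverse Ackermann function; α(n) ≤ 4 for any input size that could arise in practice, so this is effectively constant.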

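```python
class UnionFind:
    def __init__(self, n: int):
        self.parent = list(range(n))  # every element starts as its own root
        self.rank = [0] * n           # upper bound on each root's tree height
        self.components = n           # number of disjoint sets

    def find(self, x: int) -> int:
        """Find root with path compression."""
        # Union by rank keeps trees O(log n) tall, so the recursion stays shallow.
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x: int, y: int) -> bool:
        """Union by rank. Returns False if already connected."""
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return False

        # Make rx the root with the larger rank, then hang ry beneath it.
        if self.rank[rx] < self.rank[ry]:
            rx, ry = ry, rx
        self.parent[ry] = rx
        # Height can grow only when two trees of equal rank are merged.
        if self.rank[rx] == self.rank[ry]:
            self.rank[rx] += 1

        self.components -= 1
        return True

    def connected(self, x: int, y: int) -> bool:
        """True if x and y are in the same set."""
        return self.find(x) == self.find(y)
```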
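
The effect of each optimization is easy to measure in isolation. This self-contained sketch (the helpers `root`, `chain_naive`, `chain_by_rank`, `find_compress`, and `tree_height` are hypothetical, written only for this experiment) performs union(i, i + 1) for every i. Naive linking that always hangs the first root under the second builds a path of height n - 1; union by rank keeps the height at 1; and a single compressing find from the deepest node flattens the naive chain to height 1 as well.

```python
def tree_height(parent: list[int]) -> int:
    """Longest parent-pointer path from any node up to its root."""
    def depth(x: int) -> int:
        d = 0
        while parent[x] != x:
            x = parent[x]
            d += 1
        return d
    return max(depth(x) for x in range(len(parent)))


def root(parent: list[int], x: int) -> int:
    """Plain find: follow parent pointers, no compression."""
    while parent[x] != x:
        x = parent[x]
    return x


def chain_naive(n: int) -> list[int]:
    """union(i, i + 1) for every i, always hanging the first root under the second."""
    parent = list(range(n))
    for i in range(n - 1):
        parent[root(parent, i)] = root(parent, i + 1)
    return parent


def chain_by_rank(n: int) -> list[int]:
    """The same unions, but the lower-rank root goes under the higher-rank one."""
    parent, rank = list(range(n)), [0] * n
    for i in range(n - 1):
        rx, ry = root(parent, i), root(parent, i + 1)
        if rx == ry:
            continue
        if rank[rx] < rank[ry]:
            rx, ry = ry, rx
        parent[ry] = rx
        if rank[rx] == rank[ry]:
            rank[rx] += 1
    return parent


def find_compress(parent: list[int], x: int) -> int:
    """Two-pass find: locate the root, then point every node on the path at it."""
    r = root(parent, x)
    while parent[x] != r:
        nxt = parent[x]
        parent[x] = r
        x = nxt
    return r


n = 1000
naive = chain_naive(n)
print(tree_height(naive))             # 999: the sets form a linked list
print(tree_height(chain_by_rank(n)))  # 1: every node hangs directly off root 0
find_compress(naive, 0)               # one find from the deepest node
print(tree_height(naive))             # 1: the whole path now points at the root
```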

Application: Number of Islands (Union-Find)

```python
def num_islands_uf(grid: list[list[str]]) -> int:
    """Count islands using Union-Find instead of DFS."""
    if not grid:
        return 0

    rows, cols = len(grid), len(grid[0])
    uf = UnionFind(rows * cols)  # cell (r, c) maps to index r * cols + c
    water = 0

    for r in range(rows):
        for c in range(cols):
            if grid[r][c] == "0":
                water += 1  # water cells stay singleton sets
                continue
            # Union with right and down neighbors (left/up pairs were handled earlier)
            for dr, dc in [(0, 1), (1, 0)]:
                nr, nc = r + dr, c + dc
                if nr < rows and nc < cols and grid[nr][nc] == "1":
                    uf.union(r * cols + c, nr * cols + nc)

    # Each water cell is its own component, so subtract them to leave only islands.
    return uf.components - water
```
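
A design alternative, sketched under the assumption that the grid is mostly water: create sets only for land cells, keyed by (row, col) in a dict, and count islands directly (each new land cell adds one, each successful merge removes one), so no water correction is needed. `num_islands_sparse` is a hypothetical name. Because it scans in row-major order, only the up and left neighbors can already exist in the structure.

```python
def num_islands_sparse(grid: list[list[str]]) -> int:
    """Count islands, allocating union-find entries for land cells only."""
    parent: dict[tuple[int, int], tuple[int, int]] = {}

    def find(cell: tuple[int, int]) -> tuple[int, int]:
        r = cell
        while parent[r] != r:  # first pass: locate the root
            r = parent[r]
        while parent[cell] != r:  # second pass: path compression
            nxt = parent[cell]
            parent[cell] = r
            cell = nxt
        return r

    count = 0
    for r, row in enumerate(grid):
        for c, val in enumerate(row):
            if val != "1":
                continue
            parent[(r, c)] = (r, c)  # new land cell: a new island for now
            count += 1
            # Up and left were visited earlier; they are in `parent` only if land.
            for nb in ((r - 1, c), (r, c - 1)):
                if nb in parent:
                    ra, rb = find((r, c)), find(nb)
                    if ra != rb:
                        parent[ra] = rb
                        count -= 1  # two islands merged into one
    return count


grid = [
    ["1", "1", "0", "0", "0"],
    ["1", "1", "0", "0", "0"],
    ["0", "0", "1", "0", "0"],
    ["0", "0", "0", "1", "1"],
]
print(num_islands_sparse(grid))  # 3
```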

Application: Kruskal's Minimum Spanning Tree

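```python
def kruskal_mst(n: int, edges: list[tuple[int, int, int]]) -> list[tuple[int, int, int]]:
    """Find MST using Kruskal's algorithm. edges = [(weight, u, v)]"""
    uf = UnionFind(n)
    mst = []

    # sorted() leaves the caller's list untouched; tuples compare by weight first.
    for weight, u, v in sorted(edges):
        if uf.union(u, v):  # Only add if it connects new components
            mst.append((weight, u, v))
            if len(mst) == n - 1:  # a spanning tree has exactly n - 1 edges
                break

    # On a disconnected graph this is a minimum spanning forest (fewer than n - 1 edges).
    return mst
```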
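
Kruskal's loop also tells you whether the graph is connected: a spanning tree exists only if exactly n - 1 edges were accepted. The sketch below assumes the same `(weight, u, v)` edge convention as `kruskal_mst`; `mst_weight` is a hypothetical helper, and it uses a compact union-find with path halving only (no rank) for brevity.

```python
from typing import Optional


def mst_weight(n: int, edges: list[tuple[int, int, int]]) -> Optional[int]:
    """Total MST weight via Kruskal, or None if the graph is disconnected."""
    parent = list(range(n))

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    needed = max(n - 1, 0)  # a spanning tree on n vertices has n - 1 edges
    total = used = 0
    for w, u, v in sorted(edges):
        ru, rv = find(u), find(v)
        if ru != rv:  # edge joins two different components: keep it
            parent[ru] = rv
            total += w
            used += 1
            if used == needed:
                break
    return total if used == needed else None


edges = [(1, 0, 1), (4, 0, 2), (3, 1, 2), (2, 1, 3), (5, 2, 3)]
print(mst_weight(4, edges))                   # 6: edges of weight 1, 2 and 3
print(mst_weight(4, [(1, 0, 1), (2, 2, 3)]))  # None: {0, 1} and {2, 3} never join
```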

Union-Find vs BFS/DFS for connectivity: use Union-Find when edges arrive incrementally (online) and you need to merge components efficiently; note that it cannot remove edges or split sets. Use BFS/DFS when you have the full graph up front and need shortest paths (BFS, in unweighted graphs) or a specific traversal order.
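
The online case looks like this in practice: each edge is processed as it arrives, and the component count is available immediately without re-traversing the graph. `components_after_each_edge` is a hypothetical helper, and it swaps in union by size (attach the smaller set under the larger) in place of union by rank; both keep trees logarithmically shallow.

```python
from collections.abc import Iterable


def components_after_each_edge(n: int, edge_stream: Iterable[tuple[int, int]]) -> list[int]:
    """Number of connected components after each edge arrives (vertices 0..n-1)."""
    parent = list(range(n))
    size = [1] * n  # union by size, used here instead of union by rank
    count = n
    history = []

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for u, v in edge_stream:
        ru, rv = find(u), find(v)
        if ru != rv:
            if size[ru] < size[rv]:
                ru, rv = rv, ru
            parent[rv] = ru  # hang the smaller tree under the larger
            size[ru] += size[rv]
            count -= 1
        history.append(count)  # answer is ready right after each edge
    return history


print(components_after_each_edge(5, [(0, 1), (2, 3), (1, 0), (3, 4), (1, 4)]))
# [4, 3, 3, 2, 1]
```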