Source code for niemads.DisjointSet
#! /usr/bin/env python
try:
from Queue import Queue
except ImportError:
from queue import Queue
[docs]class DisjointSet:
'''``DisjointSet`` class, implemented using the Up-Tree data structure for amortized O(1) find and union operations'''
def __init__(self, initial=None):
'''``DisjointSet`` constructor
Args:
``initial`` (iterable): Elements with which to initialize the ``DisjointSet`` (each element will be in its own set)
'''
self.parent = dict() # parent[u] = parent of node u
self.num_below = dict() # num_below[u] = number of nodes below u (including u) (only current for sentinels)
if initial is not None:
for x in initial:
self.add(x)
def __contains__(self, x):
'''Check if an element ``x`` exists in this ``DisjointSet``
Args:
``x``: The element to check
Returns:
``bool``: ``True`` if ``x`` exists in this ``DisjointSet``, otherwise ``False``
'''
return x in self.parent
def __iter__(self):
''' Iterate over the elements of this ``DisjointSet``'''
for x in self.parent:
yield x
def __len__(self):
'''Return the number of elements in this ``DisjointSet``
Returns:
``int``: The number of elements contained within this ``DisjointSet``
'''
return len(self.parent)
def __str__(self):
'''Return a string representation of this ``DisjointSet``
Returns:
``str``: A string representation of this ``DisjointSet``
'''
return str(self.sets())
[docs] def add(self, x):
'''Add a new element ``x`` to this ``DisjointSet`` as a sentinel node
Args:
``x``: The element to insert
'''
if x in self:
raise ValueError("Node already exists: %s"%x)
self.parent[x] = None; self.num_below[x] = 1
[docs] def remove(self, x):
'''Remove the element ``x`` from this ``DisjointSet``
Args:
``x``: The element to remove
'''
if x not in self:
raise ValueError("Node not found: %s"%x)
p = self.parent[x]
if p is not None:
p = self.parent[x]; self.num_below[p] -= 1
for e in self.parent:
if self.parent[e] == x:
self.parent[e] = p
del self.parent[x]; del self.num_below[x]
[docs] def find(self, x):
'''Return the sentinel node of the element ``x``. Implements path compression along the search
Args:
``x``: The element to find
Returns:
The sentinel node of ``x``
'''
if x not in self:
raise ValueError("Node not found: %s"%x)
explored = Queue(); curr = x
while self.parent[curr] is not None:
explored.put(curr); curr = self.parent[curr]
while not explored.empty():
self.parent[explored.get()] = curr
return curr
[docs] def union(self, x, y):
'''Union the sets containing ``x`` and ``y``. Implements Union-By-Size
Args:
``x``: One of the two elements whose sets will be unioned
``y``: One of the two elements whose sets will be unioned
'''
if x not in self:
raise ValueError("Node not found: %s"%x)
if y not in self:
raise ValueError("Node not found: %s"%y)
sx = self.find(x); sy = self.find(y)
if sx == sy:
return
if self.num_below[sx] > self.num_below[sy]:
self.parent[sy] = sx; self.num_below[sx] += (self.num_below[sy] + 1)
else:
self.parent[sx] = sy; self.num_below[sy] += (self.num_below[sx] + 1)
[docs] def sets(self):
'''Return the sets of this ``DisjointSet``
Returns:
``list`` of ``set``: The sets of this ``DisjointSet``
'''
out_sets = dict()
for x in self.parent:
p = self.parent[x]
if p is None:
p = x
if p not in out_sets:
out_sets[p] = set()
out_sets[p].add(x)
return list(out_sets.values())