Skip to content

Commit

Permalink
add a disjoint set module
Browse files Browse the repository at this point in the history
  • Loading branch information
glguy committed Dec 19, 2024
1 parent d00ffdb commit f4b5fe6
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 0 deletions.
1 change: 1 addition & 0 deletions common/advent.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ library
Advent.Chinese
Advent.Coord
Advent.Coord3
Advent.DisjointSet
Advent.Fix
Advent.Format
Advent.Group
Expand Down
69 changes: 69 additions & 0 deletions common/src/Advent/DisjointSet.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
{-# Language BlockArguments #-}
{-|
Module : Advent.DisjointSet
Description : Implementation of a Disjoint Set datastructure
Copyright : (c) Eric Mertens, 2024
License : ISC
Maintainer : emertens@gmail.com
-}
module Advent.DisjointSet (
DisjointSet,
newDisjointSet,
unifySets,
setSize,
setRepresentative,
inSameSet,
) where

import Control.Monad ( when )
import Data.Array.IO (Ix(range), IOArray, newListArray, readArray, writeArray)

newtype DisjointSet a = DS (IOArray a (Int, a))

newDisjointSet :: Ix a => (a, a) -> IO (DisjointSet a)
newDisjointSet b =
do arr <- newListArray b [(1, x) | x <- range b]
pure (DS arr)

findRoot' :: Ix a => DisjointSet a -> a -> IO (Int, a)
findRoot' (DS arr) x =
do (sz, y) <- readArray arr x
if x == y then pure (sz, x) else findRoot' (DS arr) y

updateRoot :: Ix a => DisjointSet a -> a -> a -> IO ()
updateRoot (DS arr) root x =
when (root /= x)
do (sz, y) <- readArray arr x
writeArray arr x (sz, root)
updateRoot (DS arr) root y

findRoot :: Ix a => DisjointSet a -> a -> IO (Int, a)
findRoot ds x =
do (rank, root) <- findRoot' ds x
updateRoot ds root x
pure (rank, root)

setRepresentative :: Ix a => DisjointSet a -> a -> IO a
setRepresentative ds x = snd <$> findRoot ds x

setSize :: Ix a => DisjointSet a -> a -> IO Int
setSize ds x =
do (size, _) <- findRoot ds x
pure size

unifySets :: Ix a => DisjointSet a -> a -> a -> IO ()
unifySets (DS arr) x y =
do (sizeX, x') <- findRoot (DS arr) x
(sizeY, y') <- findRoot (DS arr) y

when (x' /= y')
if sizeX < sizeY
then writeArray arr x' (0,y') >> writeArray arr y' (sizeX + sizeY, y')
else writeArray arr y' (0,x') >> writeArray arr x' (sizeX + sizeY, x')

inSameSet :: Ix a => DisjointSet a -> a -> a -> IO Bool
inSameSet ds x y =
do x' <- setRepresentative ds x
y' <- setRepresentative ds y
pure (x' == y')

0 comments on commit f4b5fe6

Please sign in to comment.