{-# OPTIONS -Wall -XScopedTypeVariables #-}
-- | Bottom-up merge sort that does its work in a pair of mutable 'STArray's.
--
-- The input list is first written into an array as already-ordered runs of
-- two elements. Those runs are then merged pairwise, doubling the run length
-- on every pass, until a single run covers the whole array.
module ArrayMergesort
  ( mergeSort
  ) where

import Data.Array.ST
import Control.Monad.ST
import Data.Array.IArray(elems)
import qualified Data.Ord as Ord

-- | Sort a list in ascending order using the element type's 'Ord' instance.
mergeSort :: (Ord a) => [a] -> [a]
mergeSort = mergeSortBy Ord.compare

-- | Sort a list with a caller-supplied comparison function.
--
-- The elements are loaded into a 1-based array of the same length as the
-- list, sorted there, and read back out with 'elems'.
mergeSortBy :: (a -> a -> Ordering) -> [a] -> [a]
mergeSortBy cmp xs = elems $ runSTArray $ do
  buf <- newArray_ (1, length xs) :: ST s (STArray s Int a)
  mergeAndInsertArray cmp buf xs
  mergeSortArray cmp buf

-- | Write the list into the array starting at index 1, two elements at a
-- time, putting each couple in order as it goes.
--
-- Afterwards every aligned run of two slots is sorted. Elements that do not
-- compare 'GT' keep their original relative order. A trailing odd element is
-- stored unchanged.
mergeAndInsertArray :: (a -> a -> Ordering) -> STArray s Int a -> [a] -> ST s ()
mergeAndInsertArray cmp arr = fill 1
  where
    fill slot (p:q:rest) = do
      -- The comparison is evaluated here, before either element is written.
      case p `cmp` q of
        GT -> store q p
        _  -> store p q
      fill (slot + 2) rest
      where
        store lo hi = writeArray arr slot lo >> writeArray arr (slot + 1) hi
    fill slot [p] = writeArray arr slot p
    fill _    []  = return ()

-- | Sort an array whose aligned runs of two elements are already ordered.
--
-- A second array with the same bounds is allocated as scratch space. Each
-- pass merges from one array into the other and the two then swap roles, so
-- the array returned may be either the argument or the scratch array.
-- Indices are treated as starting at 1 (only the upper bound of the input is
-- used by the merge passes).
mergeSortArray :: (a -> a -> Ordering) -> STArray s Int a -> ST s (STArray s Int a)
mergeSortArray cmp arr = do
  (firstIx, finalIx) <- getBounds arr
  scratch <- newArray_ (firstIx, finalIx)
  passes 2 finalIx arr scratch
  where
    -- Run merge passes until a single sorted run spans the whole array.
    -- @runLen@ is the length of the runs that are already sorted in @src@;
    -- a "block" is two neighbouring runs that get merged into one.
    passes runLen lastIx src dst
      | runLen >= lastIx = return src
      | otherwise        = sweep 1 >> passes blockLen lastIx dst src
      where
        blockLen = 2 * runLen

        -- One merge pass over the array. Every element is read exactly once.
        sweep start
          | start > lastIx - blockLen =
              -- Final (possibly short) block: clamp both run ends to the
              -- last valid index.
              merge start mid (min (mid - 1) lastIx)
                              (min (start + blockLen - 1) lastIx)
                              start
          | otherwise = do
              merge start mid (mid - 1) (start + blockLen - 1) start
              sweep (start + blockLen)
          where
            mid = start + runLen

        -- Merge the runs @[l .. lEnd]@ and @[r .. rEnd]@ of @src@ into @dst@
        -- beginning at @out@. On a tie the left element is taken first.
        --
        -- Both exhaustion checks are performed, and both candidate values
        -- are read, on every step, which is not strictly necessary. The
        -- alternative of separate take-left / take-right functions was
        -- reported by the original author not to benchmark any faster.
        merge l r lEnd rEnd out
          | l > lEnd  = drain r out rEnd
          | r > rEnd  = drain l out rEnd
          | otherwise = do
              lv <- readArray src l
              rv <- readArray src r
              case lv `cmp` rv of
                GT -> writeArray dst out rv >> merge l (r + 1) lEnd rEnd (out + 1)
                _  -> writeArray dst out lv >> merge (l + 1) r lEnd rEnd (out + 1)

        -- Copy consecutive elements of @src@ (from @srcIx@ onwards) into
        -- @dst@ until the output position passes @lastOut@.
        drain srcIx out lastOut
          | out > lastOut = return ()
          | otherwise     = do
              readArray src srcIx >>= writeArray dst out
              drain (srcIx + 1) (out + 1) lastOut