byteverse / contiguous

Typeclass for array types
Other
19 stars 9 forks source link

`StrictSmallArray`: interop with `Data.Elevator` #65

Open Qqwy opened 1 month ago

Qqwy commented 1 month ago

For my latest project, I wanted to try out `SmallUnliftedArray`.

However, I wanted to use it with arbitrary `a`s (rather than only those for which a user provides a `PrimUnlifted` instance), so I wrote the following interop between `Data.Elevator.Strict` (from the `data-elevator` package) and `SmallUnliftedArray`, called `StrictSmallArray`:

(Note: this depends on PR #64: https://github.com/byteverse/contiguous/pull/64)

{-# LANGUAGE GHC2021 #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE UnliftedNewtypes #-}
module StrictSmallArray(StrictSmallArray(..)) where

import Prelude hiding (foldl, foldr, foldl', foldr', null, read, length)
import Data.Primitive (SmallArray, PrimArray, Prim)
import Data.Primitive.SmallArray qualified as SmallArray
import Data.Primitive.PrimArray qualified as PrimArray
import Data.Coerce (coerce)
import GHC.Exts (TYPE, Levity(..), RuntimeRep(BoxedRep), SmallArray#, SmallMutableArray#)
import Data.Primitive.Contiguous qualified as Contiguous
import Data.Primitive.Contiguous
import Data.Primitive.Contiguous.Class (Slice(..), MutableSlice(..), ContiguousU(..), Contiguous(..))

import Data.Primitive.Unlifted.Class (PrimUnlifted(..))
import Data.Primitive.Unlifted.SmallArray (SmallUnliftedArray_(..), SmallMutableUnliftedArray_(..))
import Data.Primitive.Unlifted.SmallArray.Primops (SmallUnliftedArray# (SmallUnliftedArray#), SmallMutableUnliftedArray# (SmallMutableUnliftedArray#))
import Data.Kind (Type)
import Data.Elevator (Strict(Strict), UnliftedType)
import Control.DeepSeq (NFData)
import Data.Hashable (Hashable)
import Control.Monad.ST (runST)

-- | Wrapper that gives /any/ lifted type a 'PrimUnlifted' instance, using
-- 'Data.Elevator.Strict' as its unlifted representation.
--
-- The 'Show', 'Eq', 'Ord', 'Hashable' and 'NFData' instances are the ones
-- of the wrapped type (derived via the newtype).
newtype Strictly a = Strictly
  { unStrictly :: a
  }
  deriving newtype (Eq, Ord, Show, NFData, Hashable)

-- | Store a 'Strictly' value as a 'Strict' from "Data.Elevator".
--
-- 'Strict' @a@ is an unlifted type, so its values are never thunks. Routing
-- elements through this instance is what lets 'SmallUnliftedArray_'-based
-- containers hold any element type while keeping each element evaluated.
instance PrimUnlifted (Strictly a) where
  type Unlifted (Strictly a) = Strict a
  toUnlifted# (Strictly a) = Strict a
  fromUnlifted# (Strict a) = Strictly a

-- | Immutable array type whose elements are guaranteed to be in WHNF.
--
-- An easier-to-use version of 'SmallUnliftedArray' that can store /any/
-- element type @a@. Each element is wrapped in 'Strictly', whose
-- 'PrimUnlifted' representation is 'Data.Elevator.Strict'.
newtype StrictSmallArray a = StrictSmallArray (SmallUnliftedArray_ (Strict a) (Strictly a))
  -- Explicit strategy for consistency with the rest of the module. GHC
  -- derives 'Show' with the stock strategy even when
  -- GeneralizedNewtypeDeriving (part of GHC2021) is on, so behavior is
  -- unchanged.
  deriving stock (Show)

-- | Mutable counterpart of 'StrictSmallArray': every element stored in it
-- is guaranteed to be in WHNF.
--
-- A friendlier alternative to 'SmallMutableUnliftedArray' that accepts
-- /any/ element type @a@. Elements are held as 'Data.Elevator.Strict'
-- values through the 'Strictly' wrapper.
newtype StrictSmallMutableArray s a
  = StrictSmallMutableArray
      (SmallMutableUnliftedArray_ (Strict a) s (Strictly a))

-- | Unlifted version of 'StrictSmallArray' itself: a newtype over the raw
-- @'SmallArray#' ('Strict' a)@.
--
-- Its kind is @Type -> UnliftedType@; a fully applied
-- @StrictSmallArray# a@ has kind 'UnliftedType'. The standalone kind
-- signature states this explicitly instead of relying on UnliftedNewtypes'
-- kind inference.
type StrictSmallArray# :: Type -> UnliftedType
newtype StrictSmallArray# (a :: Type)
  = StrictSmallArray# (SmallArray# (Strict a))

-- | Unlifted version of 'StrictSmallMutableArray' itself: a newtype over
-- the raw @'SmallMutableArray#' s ('Strict' a)@.
--
-- Its kind is @Type -> Type -> UnliftedType@ (state token, then element
-- type). A fully applied @StrictSmallMutableArray# s a@ has kind
-- 'UnliftedType'. The standalone kind signature makes this explicit.
type StrictSmallMutableArray# :: Type -> Type -> UnliftedType
newtype StrictSmallMutableArray# s (a :: Type)
  = StrictSmallMutableArray# (SmallMutableArray# s (Strict a))

-- | 'Contiguous' instance for 'StrictSmallArray'.
--
-- Every method unwraps 'StrictSmallArray' / 'StrictSmallMutableArray',
-- delegates to the same method of the underlying 'SmallUnliftedArray_' /
-- 'SmallMutableUnliftedArray_' instance, and re-wraps the result. Elements
-- enter through 'Strictly' and leave through 'unStrictly'.
instance Contiguous.Contiguous StrictSmallArray where
  type Mutable StrictSmallArray  = StrictSmallMutableArray
  -- No element constraint: any @a@ can be stored, thanks to 'Strictly'.
  type Element StrictSmallArray = Always
  type Sliced StrictSmallArray = Slice StrictSmallArray
  type MutableSliced (StrictSmallArray) = MutableSlice (StrictSmallArray)
  {-# INLINE new #-}
  new n = StrictSmallMutableArray <$> new n
  {-# INLINE replicateMut #-}
  replicateMut n x = StrictSmallMutableArray <$> replicateMut n (Strictly x)
  {-# INLINE shrink #-}
  shrink (StrictSmallMutableArray arr) n = StrictSmallMutableArray <$> shrink arr n
  {-# INLINE empty #-}
  empty = StrictSmallArray empty
  {-# INLINE singleton #-}
  singleton = StrictSmallArray . singleton . Strictly
  {-# INLINE doubleton #-}
  doubleton a b = StrictSmallArray $ doubleton (Strictly a) (Strictly b)
  {-# INLINE tripleton #-}
  tripleton a b c = StrictSmallArray $ tripleton (Strictly a) (Strictly b) (Strictly c)
  {-# INLINE quadrupleton #-}
  quadrupleton a b c d = StrictSmallArray $ quadrupleton (Strictly a) (Strictly b) (Strictly c) (Strictly d)
  {-# INLINE quintupleton #-}
  quintupleton a b c d e = StrictSmallArray $ quintupleton (Strictly a) (Strictly b) (Strictly c) (Strictly d) (Strictly e)
  {-# INLINE sextupleton #-}
  sextupleton a b c d e f = StrictSmallArray $ sextupleton (Strictly a) (Strictly b) (Strictly c) (Strictly d) (Strictly e) (Strictly f)
  {-# INLINE index #-}
  index (StrictSmallArray ary) idx = unStrictly $ index ary idx
  {-# INLINE index# #-}
  index# (StrictSmallArray ary) idx | (# v #) <- index# ary idx = (# unStrictly v #)
  {-# INLINE indexM #-}
  indexM (StrictSmallArray ary) idx = unStrictly <$> indexM ary idx
  {-# INLINE size #-}
  size (StrictSmallArray ary) = size ary
  {-# INLINE sizeMut #-}
  sizeMut (StrictSmallMutableArray ary) = sizeMut ary
  {-# INLINE equals #-}
  equals (StrictSmallArray lhs) (StrictSmallArray rhs) = equals lhs rhs
  {-# INLINE equalsMut #-}
  equalsMut (StrictSmallMutableArray lhs) (StrictSmallMutableArray rhs) = equalsMut lhs rhs
  {-# INLINE rnf #-}
  rnf (StrictSmallArray ary) = rnf ary
  {-# INLINE null #-}
  null (StrictSmallArray ary) = null ary
  {-# INLINE read #-}
  read (StrictSmallMutableArray ary) idx = unStrictly <$> read ary idx
  {-# INLINE write #-}
  write (StrictSmallMutableArray ary) idx x = write ary idx (Strictly x)
  -- Slices store the array in its unlifted form, obtained with 'unlift' /
  -- 'unliftMut' from the 'ContiguousU' instance.
  {-# INLINE slice #-}
  slice base offset length = Slice {offset, length, base = unlift base}
  {-# INLINE sliceMut #-}
  sliceMut baseMut offsetMut lengthMut = MutableSlice {offsetMut, lengthMut, baseMut = unliftMut baseMut}
  {-# INLINE toSlice #-}
  toSlice base = Slice {offset = 0, length = size base, base = unlift base}
  {-# INLINE toSliceMut #-}
  toSliceMut baseMut = do
    lengthMut <- sizeMut baseMut
    pure MutableSlice {offsetMut = 0, lengthMut, baseMut = unliftMut baseMut}
  {-# INLINE clone_ #-}
  clone_ (StrictSmallArray ary) offset length = StrictSmallArray $ clone_ ary offset length
  {-# INLINE cloneMut_ #-}
  cloneMut_ (StrictSmallMutableArray ary) offset length = StrictSmallMutableArray <$> cloneMut_ ary offset length
  {-# INLINE copy_ #-}
  copy_ (StrictSmallMutableArray dst) dstOffset (StrictSmallArray src) srcOffset length = copy_ dst dstOffset src srcOffset length
  {-# INLINE copyMut_ #-}
  copyMut_ (StrictSmallMutableArray dst) dstOffset (StrictSmallMutableArray src) srcOffset length = copyMut_ dst dstOffset src srcOffset length
  {-# INLINE freeze_ #-}
  freeze_ (StrictSmallMutableArray ary) offset length = StrictSmallArray <$> freeze_ ary offset length
  {-# INLINE unsafeFreeze #-}
  unsafeFreeze (StrictSmallMutableArray ary) = StrictSmallArray <$> unsafeFreeze ary
  {-# INLINE unsafeShrinkAndFreeze #-}
  unsafeShrinkAndFreeze (StrictSmallMutableArray ary) length = StrictSmallArray <$> unsafeShrinkAndFreeze ary length
  {-# INLINE thaw_ #-}
  thaw_ (StrictSmallArray ary) offset length = StrictSmallMutableArray <$> thaw_ ary offset length
  -- INLINE added for consistency with every other method of this instance.
  -- NOTE: not relying on a manually-written run-st here as modern GHCs inline runST properly.
  {-# INLINE run #-}
  run = runST

-- | 'ContiguousU' instance for 'StrictSmallArray'.
--
-- 'unlift' / 'lift' and their mutable variants only remove or re-apply the
-- wrapper constructors around the raw 'SmallArray#' / 'SmallMutableArray#';
-- no elements are touched or copied. 'resize' delegates to the underlying
-- 'SmallMutableUnliftedArray_'.
instance Contiguous.ContiguousU StrictSmallArray where
  type Unlifted StrictSmallArray = StrictSmallArray#
  type UnliftedMut StrictSmallArray = StrictSmallMutableArray#

  {-# INLINE resize #-}
  resize (StrictSmallMutableArray marr) newLen =
    fmap StrictSmallMutableArray (resize marr newLen)

  {-# INLINE unlift #-}
  unlift (StrictSmallArray boxed) = case boxed of
    SmallUnliftedArray (SmallUnliftedArray# raw) -> StrictSmallArray# raw

  {-# INLINE unliftMut #-}
  unliftMut (StrictSmallMutableArray boxed) = case boxed of
    SmallMutableUnliftedArray (SmallMutableUnliftedArray# raw) -> StrictSmallMutableArray# raw

  {-# INLINE lift #-}
  lift (StrictSmallArray# raw) =
    StrictSmallArray (SmallUnliftedArray (SmallUnliftedArray# raw))

  {-# INLINE liftMut #-}
  liftMut (StrictSmallMutableArray# raw) =
    StrictSmallMutableArray (SmallMutableUnliftedArray (SmallMutableUnliftedArray# raw))

I recommend comparing the generated Cmm for the following two functions:

-- | Sum a lazy 'SmallArray' with 'foldr''. Elements of a 'SmallArray' may
-- be thunks, so the compiled loop has to check each element's pointer tag
-- and possibly enter it before reading the 'Int' (see the Cmm for
-- @$wsumLazyArray@).
sumLazyArray :: SmallArray Int -> Int
sumLazyArray = foldr' (+) 0

-- | The same sum over a 'StrictSmallArray'. Its elements are never thunks,
-- so the compiled loop reads each 'Int' directly, with no tag check, call,
-- or stack spill (see the Cmm for @$wsumStrictArray@).
sumStrictArray :: StrictSmallArray Int -> Int
sumStrictArray = foldr' (+) 0

The lazy one looks like:

[$wsumLazyArray_entry() { //  [R2]
         { info_tbls: [(cbEn,
                        label: $wsumLazyArray_info
                        rep: HeapRep static { Fun {arity: 1 fun_type: ArgSpec 5} }
                        srt: Nothing),
                       (cbEE,
                        label: block_cbEE_info
                        rep: StackRep [False, True, True]
                        srt: Nothing)]
           stack_info: arg_space: 8
         }
     {offset
       cbEn:
           if ((Sp + -32) < SpLim) (likely: False) goto cbEo; else goto cbEp;
       cbEo:
           R1 = $wsumLazyArray_closure;
           call (stg_gc_fun)(R2, R1) args: 8, res: 0, upd: 8;
       cbEp:
           _sauT::P64 = R2;
           _sauW::I64 = 0;
           _sauV::I64 = I64[R2 + 8] - 1;
           goto cbEw;
       cbEw:
           if (_sauV::I64 != (-1)) goto cbEG; else goto cbEM;
       cbEG:
           (_cbEC::P64) = call MO_AtomicRead W64 MemOrderAcquire((_sauT::P64 + 16) + (_sauV::I64 << 3));
           I64[Sp - 32] = cbEE;
           R1 = _cbEC::P64;
           P64[Sp - 24] = _sauT::P64;
           I64[Sp - 16] = _sauW::I64;
           I64[Sp - 8] = _sauV::I64;
           Sp = Sp - 32;
           if (R1 & 7 != 0) goto cbEE; else goto cbEH;
       cbEH:
           call (I64[R1])(R1) returns to cbEE, args: 8, res: 8, upd: 8;
       cbEE:
           _sauT::P64 = P64[Sp + 8];
           _sauW::I64 = I64[R1 + 7] + I64[Sp + 16];
           _sauV::I64 = I64[Sp + 24] - 1;
           Sp = Sp + 32;
           goto cbEw;
       cbEM:
           R1 = _sauW::I64;
           call (P64[Sp])(R1) args: 8, res: 0, upd: 8;
     }
 }

And the strict one looks like:

[$wsumStrictArray_entry() { //  [R2]
         { info_tbls: [(cbD8,
                        label: $wsumStrictArray_info
                        rep: HeapRep static { Fun {arity: 1 fun_type: ArgSpec 5} }
                        srt: Nothing)]
           stack_info: arg_space: 8
         }
     {offset
       cbD8:
           _sauC::P64 = R2;
           _sauF::I64 = 0;
           _sauE::I64 = I64[R2 + 8] - 1;
           goto cbDh;
       cbDh:
           if (_sauE::I64 != (-1)) goto cbDn; else goto cbDo;
       cbDn:
           (_cbDq::P64) = call MO_AtomicRead W64 MemOrderAcquire((_sauC::P64 + 16) + (_sauE::I64 << 3));
           _sauF::I64 = I64[_cbDq::P64 + 7] + _sauF::I64;
           _sauE::I64 = _sauE::I64 - 1;
           goto cbDh;
       cbDo:
           R1 = _sauF::I64;
           call (P64[Sp])(R1) args: 8, res: 0, upd: 8;
     }
 }

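Notice that not only did the thunk check (the pointer-tag test `R1 & 7 != 0`) and the indirect jump into the element's closure disappear, but so did all the code that spills registers to the stack before that call and restores them afterwards. Because every element of a `StrictSmallArray` is known to be evaluated, the loop never has to call out, and its state can stay in registers. Huge success!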


If you like, we could include `StrictSmallArray` in `contiguous` itself. Note that `data-elevator` is a tiny library with no dependencies besides `base`.

If you'd rather not, I could publish this as a separate library instead.

Qqwy commented 1 month ago

What would your preference be? :relaxed: