joelberkeley / spidr

Accelerated machine learning with dependent types
Apache License 2.0

Monads in Idris are slow #366

Open · joelberkeley opened this issue 9 months ago

joelberkeley commented 9 months ago

See Stefan Hoeck's thread on Discord:

Unfortunately, the current implementation of `ST` does not provide the performance benefits we'd want from it, because the resulting monadic code is not further optimized by the Idris compiler. Because it is opaque from the outside (its constructors are not exported), there is no way to convert it to a manually written, unwrapped loop, as is possible with `PrimIO` (or with an approach based on linear types). Therefore, I'm afraid `ST` in its current state is not very useful. I know, this does not answer your actual question...

So, I actually did some benchmarking and tested a program similar to the following in five different ways: using direct recursion without any kind of mutable reference, using `IORef`, using `STRef`, using a linear mutable reference (code further below), and finally, using `Control.Monad.State`.

```Idris
module IO

import Data.IORef

%default total

sum : Nat -> IORef Nat -> IO Nat
sum 0 ref = readIORef ref
sum (S k) ref = do
  n <- readIORef ref
  writeIORef ref (S n)
  sum k ref

main : IO ()
main = newIORef 0 >>= sum 10000000000 >>= printLn
```

The time taken to run 10 billion iterations was as follows:

- direct recursion: 8.4 s
- linear mutable refs: 38.4 s
- `STRef`: 48.6 s
- `IORef`: 48.4 s
- State monad: 4 min 8 s

As can be seen, plain tail recursion clearly beats every form of mutable state. I assume (and I'll try to verify this later on) that the linear version comes with a certain overhead because reading a linear mutable reference must return a pair (the result plus the new mutable reference of quantity 1). We can avoid this in this simple example by implementing a `modify` function as a third primitive.

Note: of all the versions tested, only the direct recursive one and the linear one were actually tail-recursive and thus stack safe, which is important for backends such as the JavaScript ones. All the others will overflow the stack.

As expected, the linear version could be improved by adding a new `modify` primitive. This reduced the time to 18.5 s.
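The direct-recursion variant that won the benchmark is not shown in the thread. A minimal sketch of what it might look like, assuming the counter is simply threaded through as an accumulator argument; the module name `Direct` and the function name `countUp` are hypothetical, not taken from Stefan's code:

```idris
module Direct

%default total

-- Hypothetical baseline: no mutable reference at all, just an accumulator.
-- The recursive call is in tail position, so this is a plain loop.
countUp : Nat -> Nat -> Nat
countUp 0     acc = acc
countUp (S k) acc = countUp k (S acc)

main : IO ()
main = printLn (countUp 10000000000 0)
```

Because the recursive call is in tail position, this sketch should share the stack-safety property Stefan describes for his direct-recursion version.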
Code for the linear version:

```Idris
module Linear

import public Data.Linear

%default total

--------------------------------------------------------------------------------
--          Linear Mutable Reference
--------------------------------------------------------------------------------

data Mut : Type -> Type where [external]

%extern prim__newIORef : forall a . a -> %World -> Mut a
%extern prim__readIORef : forall a . Mut a -> %World -> a
%extern prim__writeIORef : forall a . Mut a -> (1 val : a) -> PrimIO ()

-- Consumes the linear world token and returns the value unchanged.
destroy : (1 _ : %World) -> a -> a
destroy %MkWorld x = x

set' : a -> Mut a -> Mut a
set' y z = let MkIORes () w2 := prim__writeIORef z y %MkWorld in destroy w2 z

0 Ur : Type -> Type
Ur = (!*)

0 CRes : Type -> Type -> Type
CRes a b = Res a (const b)

record MRef (a : Type) where
  constructor MR
  mut : Mut a

alloc : a -> (1 fun : MRef a -@ Ur b) -> Ur b
alloc v f = f (MR $ prim__newIORef v %MkWorld)

set : a -> MRef a -@ MRef a
set x (MR mut) = MR $ set' x mut

get : MRef a -@ CRes a (MRef a)
get (MR mut) = prim__readIORef mut %MkWorld # MR mut

-- Reads and writes in one step, so no intermediate result pair is built.
modify : (a -> a) -> MRef a -@ MRef a
modify f (MR mut) =
  let v := prim__readIORef mut %MkWorld
   in MR $ set' (f v) mut

extract : MRef a -@ Ur a
extract (MR mut) = MkBang $ prim__readIORef mut %MkWorld

--------------------------------------------------------------------------------
--          Example App
--------------------------------------------------------------------------------

sum : Nat -> MRef Nat -@ Ur Nat
sum 0 x = extract x
sum (S k) x = let x2 := modify S x in sum k x2

main : IO ()
main = printLn (unrestricted $ alloc 0 (sum 10000000000))
```

Yeah, unfortunately monadic code in Idris does not currently get the same boost from aggressive inlining as it would in Haskell. Not even `IO` is always as fast as it could be, even though it gets some special treatment in the compiler.
Key takeaway: if you want raw speed in Idris, avoid monads at all costs and stick to plain recursion and pattern matching (plus `PrimIO` or linear types if you need mutable state).
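Stefan notes that, unlike `ST`, `PrimIO` code can be written as a manually unwrapped loop. A minimal sketch of the `IORef` benchmark in that style, reusing the trick from the linear code above (re-declaring `Mut` and the backend's IORef externs), but assuming the externs can be given their usual `PrimIO` types. The names `PrimLoop`, `sumLoop` and `runSum` are hypothetical, and this sketch has not been benchmarked:

```idris
module PrimLoop

%default total

-- Opaque, externally implemented mutable cell plus the backend's IORef
-- primitives, declared here with PrimIO types (an assumption of this sketch).
data Mut : Type -> Type where [external]

%extern prim__newIORef : forall a . a -> PrimIO (Mut a)
%extern prim__readIORef : forall a . Mut a -> PrimIO a
%extern prim__writeIORef : forall a . Mut a -> (1 val : a) -> PrimIO ()

-- The world token `w` is threaded by hand, so there is no monadic bind
-- for the compiler to optimise; the recursive call is a tail call.
sumLoop : Nat -> Mut Nat -> PrimIO Nat
sumLoop 0     ref w = prim__readIORef ref w
sumLoop (S k) ref w =
  let MkIORes n  w2 = prim__readIORef ref w
      MkIORes () w3 = prim__writeIORef ref (S n) w2
   in sumLoop k ref w3

-- Only the outermost call is lifted back into IO.
runSum : Nat -> IO Nat
runSum k = fromPrim (\w =>
  let MkIORes ref w2 = prim__newIORef (the Nat 0) w
   in sumLoop k ref w2)

main : IO ()
main = runSum 10000000000 >>= printLn
```

The design point is that `IO` only appears at the very edge (`fromPrim` in `runSum`); everything inside the loop is ordinary tail-recursive code over an explicit world token.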
joelberkeley commented 7 months ago

also search for the thread "monads are slow" in Discord #general

joelberkeley commented 1 month ago

possible duplicate of #399

joelberkeley commented 2 weeks ago

also see this thread, where Stefan suggests using at most `EitherT e IO`, and tail recursion instead of, e.g., `traverse`, which is "both slow and not stack safe".
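A minimal sketch of what "tail recursion instead of `traverse`" can look like for `IO`, assuming results are accumulated in a `SnocList` and converted back to a `List` once at the end. `traverseAcc` and `traverseIO` are hypothetical names; whether this is actually stack safe on a given backend (it relies on the `IO` bind being inlined so the recursive call ends up in tail position) has not been checked here:

```idris
module TailTraverse

%default total

-- Hypothetical helper: walks the input list once, snocing each result onto
-- an accumulator, and recurses as the last step of each iteration.
traverseAcc : (a -> IO b) -> SnocList b -> List a -> IO (List b)
traverseAcc f acc []        = pure (acc <>> [])
traverseAcc f acc (x :: xs) = do
  y <- f x
  traverseAcc f (acc :< y) xs

-- Drop-in replacement for `traverse f` specialised to IO and List.
traverseIO : (a -> IO b) -> List a -> IO (List b)
traverseIO f = traverseAcc f [<]

main : IO ()
main = do
  ys <- traverseIO (\x => pure (x * 2)) [1, 2, 3]
  printLn ys
```

The `SnocList` accumulator keeps results in order without reversing a list at the end: `acc <>> []` turns it back into a `List` in a single pass.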

also

IO profits from its special treatment only when a `HasIO io =>` function is inlined (many IO functions in the Prelude are IO-polymorphic and not inlined), so yes, it's mostly about polymorphic code.
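To illustrate the point about inlining, a minimal sketch contrasting an ordinary `HasIO`-polymorphic helper with one marked `%inline`. The names `incr` and `incrInline` are hypothetical, and the sketch assumes that `%inline` is enough for the compiler to specialise the helper to `IO` at each call site; the performance difference has not been measured here:

```idris
module InlineIO

import Data.IORef

%default total

-- IO-polymorphic helper: per the comment above, it only benefits from IO's
-- special treatment if it gets inlined at the call site.
incr : HasIO io => IORef Nat -> io ()
incr ref = modifyIORef ref S

-- Same helper, but the %inline pragma asks the compiler to inline it
-- wherever it is called, so `io` can be resolved to IO there.
%inline
incrInline : HasIO io => IORef Nat -> io ()
incrInline ref = modifyIORef ref S

main : IO ()
main = do
  ref <- newIORef (the Nat 0)
  incr ref
  incrInline ref
  readIORef ref >>= printLn -- prints 2
```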