composewell / streamly

High performance, concurrent functional programming abstractions
https://streamly.composewell.com

Performance issue when fusing scan and fold #153

Open harendra-kumar opened 5 years ago

harendra-kumar commented 5 years ago

The following code takes 99 μs to compute the sum and the product of a stream of 100000 numbers:

{-# INLINE sumProductFold #-}
sumProductFold :: Monad m => Stream m Int -> m (Int, Int)
sumProductFold = S.foldl' (\(s, p) x -> (s + x, p * x)) (0, 1)

If we split it into a scan followed by a fold, it takes 9.9 ms, which is 100x slower:

{-# INLINE sumProductScan #-}
sumProductScan :: Monad m => Stream m Int -> m (Int, Int)
sumProductScan = S.foldl' (\(_, p) (s0, x) -> (s0, p * x)) (0, 1)
    . S.scanl' (\(s, _) x -> (s + x, x)) (0, 0)

Using the sum s0 coming from the scan works well when the fold's accumulator is not a tuple: computing (p * x + s0) instead of (s0, p * x) takes only 99 μs. Even with a tuple accumulator, performance is fine as long as the incoming sum is not stored in the tuple, e.g. accumulating (0, p * x):

{-# INLINE sumProductScan0 #-}
sumProductScan0 :: Monad m => Stream m Int -> m (Int, Int)
sumProductScan0 = S.foldl' (\(_, p) (_, x) -> (0, p * x)) (0, 1)
    . S.scanl' (\(s, _) x -> (s + x, x)) (0, 0)

But as soon as we accumulate (s0, p * x) in the fold, performance drops by 100x. We need to examine the GHC Core to find out what is going on here.
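As a rough, base-only illustration (not streamly, and not a benchmark), the two pipelines above can be mirrored over plain lists with Data.List.foldl' and Data.List.scanl'. The names sumProductFoldL and sumProductScanL are hypothetical. Data.List.scanl' emits its seed before any element, so this sketch drops that first (0, 0) to keep the product from becoming 0; that detail is about the list function only and says nothing about the timings reported above.

```haskell
import Data.List (foldl', scanl')

-- Hypothetical list analogue of sumProductFold: a single strict left fold
-- whose tuple accumulator carries both the running sum and the product.
sumProductFoldL :: [Int] -> (Int, Int)
sumProductFoldL = foldl' (\(s, p) x -> (s + x, p * x)) (0, 1)

-- Hypothetical list analogue of sumProductScan: the scan computes the
-- running sum and passes each element along, then the fold computes the
-- product while keeping the latest sum. Data.List.scanl' emits its seed
-- (0, 0) first; that element would multiply the product by 0, so it is
-- dropped with drop 1.
sumProductScanL :: [Int] -> (Int, Int)
sumProductScanL =
    foldl' (\(_, p) (s0, x) -> (s0, p * x)) (0, 1)
  . drop 1
  . scanl' (\(s, _) x -> (s + x, x)) (0, 0)
```

Both functions return (15, 120) for [1 .. 5].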

harendra-kumar commented 5 years ago

The problem was that the accumulator itself was not a strict data structure. I changed that to a strict pair:

data Pair a b = Pair !a !b deriving (Generic, NFData)

{-# INLINE sumProductScan #-}
sumProductScan :: Monad m => Stream m Int -> m (Pair Int Int)
sumProductScan = S.foldl' (\(Pair _ p) (s0, x) -> Pair s0 (p * x)) (Pair 0 1)
    . S.scanl' (\(s, _) x -> (s + x, x)) (0, 0)

And the performance became identical to the single fold:

serially/mixed/sum-product-fold          mean 97.81 μs  ( +- 3.458 μs  )
serially/mixed/sum-product-scan0         mean 98.23 μs  ( +- 3.391 μs  )
serially/mixed/sum-product-scan          mean 98.39 μs  ( +- 4.225 μs  )
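The strictness difference behind this fix can be observed without streamly. Data.List.foldl' evaluates the accumulator only to weak head normal form, i.e. its outermost constructor: for a lazy tuple the components may remain unevaluated thunks, while the bangs on Pair force both fields whenever the constructor is evaluated. The sketch below, with hypothetical helper names lazyAcc, strictAcc and throws, puts undefined in the second field to make this visible; it shows the semantics only and does not reproduce the 100x slowdown.

```haskell
import Control.Exception (SomeException, evaluate, try)
import Data.List (foldl')

-- Strict pair as in the comment above: the bangs force both fields
-- whenever a Pair constructor is evaluated.
data Pair a b = Pair !a !b

-- Lazy tuple accumulator: foldl' forces only the (,) constructor, so the
-- undefined second component is never evaluated and fst succeeds.
lazyAcc :: [Int] -> Int
lazyAcc = fst . foldl' (\(s, _) x -> (s + x, undefined :: Int)) (0, 0)

-- Strict Pair accumulator: evaluating each new Pair also evaluates the
-- undefined field, so this throws for any non-empty input.
strictAcc :: [Int] -> Int
strictAcc xs =
  case foldl' (\(Pair s _) x -> Pair (s + x) (undefined :: Int)) (Pair 0 0) xs of
    Pair s _ -> s

-- Hypothetical helper: True if evaluating the value to WHNF throws.
throws :: Int -> IO Bool
throws v = do
  r <- try (evaluate v) :: IO (Either SomeException Int)
  pure (either (const True) (const False) r)
```

Here lazyAcc [1 .. 5] evaluates to 15, while evaluating strictAcc [1 .. 5] throws.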

Is there a way to avoid this pitfall?

pranaysashank commented 4 years ago

An alternative would be to deepseq the data structure, which would require an NFData instance. The drawback is the repeated traversal of a data structure, part of which is probably already strict (for example, in a Tree).
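A minimal sketch of the deepseq approach, assuming the deepseq package (Control.DeepSeq, which ships with GHC) and a plain list in place of a streamly stream: after each step the new accumulator is passed through force, so it is fully evaluated rather than only to WHNF. The names foldlNF and sumProductNF are hypothetical. The cost mentioned above is visible in the code: force traverses the entire accumulator on every element, including parts that are already evaluated.

```haskell
import Control.DeepSeq (NFData, force)
import Data.List (foldl')

-- Hypothetical deepseq-based strict fold: foldl' evaluates each new
-- accumulator to WHNF, and wrapping it in force turns that into full
-- (normal form) evaluation. Requires an NFData instance for the
-- accumulator and re-traverses all of it on every step.
foldlNF :: NFData b => (b -> a -> b) -> b -> [a] -> b
foldlNF f = foldl' (\acc x -> force (f acc x))

-- The lazy tuple accumulator from the issue; deep forcing keeps both
-- components evaluated at every step.
sumProductNF :: [Int] -> (Int, Int)
sumProductNF = foldlNF (\(s, p) x -> (s + x, p * x)) (0, 1)
```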

harendra-kumar commented 4 years ago

I also thought of using an NFData constraint on the fold accumulator, but I think we need compile-time enforcement, not runtime evaluation. If we could somehow determine at compile time that all the leaves of a container are strict, and fail compilation if they are not, that would be helpful.

harendra-kumar commented 4 years ago

Fuzzy thoughts: can we put a constraint on the type and, when that constraint is present, introspect the type to make sure that every leaf either has a bang (strictness annotation) or is defined in a module using StrictData?
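One possible way to sketch this with GHC.Generics, offered as an assumption rather than anything streamly provides: the metadata of each field in a Generic representation records GHC's decided strictness (DecidedLazy, DecidedStrict or DecidedUnpack), which already accounts for bangs and StrictData. A closed type family (AllFieldsStrict, hypothetical) can walk Rep b and raise a custom TypeError on any lazy field, and a hypothetical strictFoldl over lists can require that constraint. This checks only the accumulator type's own fields, not the leaves of the types stored in them, so it covers the Pair case but not the full recursive check described above.

```haskell
{-# LANGUAGE ConstraintKinds, DataKinds, DeriveGeneric, FlexibleContexts,
             KindSignatures, TypeFamilies, TypeOperators,
             UndecidableInstances #-}

import Data.Kind (Constraint, Type)
import Data.List (foldl')
import GHC.Generics
import GHC.TypeLits (ErrorMessage (..), TypeError)

-- Walks a Generic representation and produces a type error if any field
-- was decided lazy. Equations are tried top to bottom, so a lazy field
-- selector is caught before the general selector case. Field types
-- themselves (the K1 leaves) are not inspected.
type family AllFieldsStrict (f :: Type -> Type) :: Constraint where
  AllFieldsStrict (M1 S ('MetaSel mn su ss 'DecidedLazy) f) =
    TypeError ('Text "fold accumulator has a lazy field")
  AllFieldsStrict (M1 S m f) = ()
  AllFieldsStrict (M1 i m f) = AllFieldsStrict f
  AllFieldsStrict (f :*: g) = (AllFieldsStrict f, AllFieldsStrict g)
  AllFieldsStrict (f :+: g) = (AllFieldsStrict f, AllFieldsStrict g)
  AllFieldsStrict U1 = ()
  AllFieldsStrict V1 = ()

-- Hypothetical fold that only accepts accumulators whose fields are all
-- strict; any other accumulator type fails to compile.
strictFoldl :: (Generic b, AllFieldsStrict (Rep b))
            => (b -> a -> b) -> b -> [a] -> b
strictFoldl = foldl'

-- The strict pair from the thread, with a Generic instance added.
data Pair a b = Pair !a !b deriving (Show, Eq, Generic)

sumProductPair :: [Int] -> Pair Int Int
sumProductPair = strictFoldl (\(Pair s p) x -> Pair (s + x) (p * x)) (Pair 0 1)
```

With this, using the lazy tuple (0, 1) as the seed of strictFoldl should be rejected at compile time, since the Generic instance of (,) reports its fields as DecidedLazy.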