harendra-kumar opened this issue 5 years ago
The problem was that the fold accumulator itself (a tuple) was not a strict data structure. I changed it to a strict pair:
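```haskell
-- Strict pair: both fields are forced to WHNF whenever a Pair is constructed
data Pair a b = Pair !a !b deriving (Generic, NFData)

{-# INLINE sumProductScan #-}
sumProductScan :: Monad m => Stream m Int -> m (Pair Int Int)
-- scan pairs the running sum with each element; fold keeps the latest sum and the product
sumProductScan = S.foldl' (\(Pair _ p) (s0, x) -> Pair s0 (p P.* x)) (Pair 0 1)
    . S.scanl' (\(s, _) x -> (s + x, x)) (0, 0)
```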
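The same mechanism can be sketched without streamly, using base's `Data.List.scanl'` and `foldl'` on an ordinary list. This is an illustration under that assumption, not the benchmark code, and it does not reproduce the benchmark timings (list fusion and GHC's strictness analysis may behave differently from streamly's stream fusion). `foldl'` only forces the accumulator to weak head normal form: for a lazy tuple that evaluates just the `(,)` constructor, so the sum and product components can build up as unevaluated thunks, whereas the bang fields of `Pair` force both components at every step. The names `runningSums`, `sumProductLazy` and `sumProductStrict` are hypothetical.

```haskell
import Data.List (foldl', scanl')

-- Strict pair: constructing a Pair forces both fields to WHNF.
data Pair a b = Pair !a !b deriving (Show, Eq)

-- Stage 1: pair the running sum with the current element.
-- Data.List.scanl' emits the seed (0, 0) first, so it is dropped here.
runningSums :: [Int] -> [(Int, Int)]
runningSums = drop 1 . scanl' (\(s, _) x -> (s + x, x)) (0, 0)

-- Lazy accumulator: foldl' forces only the (,) constructor, so s0 and
-- p * x stay unevaluated until the final result is demanded.
sumProductLazy :: [Int] -> (Int, Int)
sumProductLazy = foldl' (\(_, p) (s0, x) -> (s0, p * x)) (0, 1) . runningSums

-- Strict accumulator: building the Pair forces s0 and p * x at every step.
sumProductStrict :: [Int] -> Pair Int Int
sumProductStrict =
  foldl' (\(Pair _ p) (s0, x) -> Pair s0 (p * x)) (Pair 0 1) . runningSums

main :: IO ()
main = do
  print (sumProductLazy [1 .. 10])    -- (55,3628800)
  print (sumProductStrict [1 .. 10])  -- Pair 55 3628800
```

Both versions compute the same result; only when the components get evaluated differs.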
And the performance became identical to that of the single fold:
```
serially/mixed/sum-product-fold   mean 97.81 μs ( +- 3.458 μs )
serially/mixed/sum-product-scan0  mean 98.23 μs ( +- 3.391 μs )
serially/mixed/sum-product-scan   mean 98.39 μs ( +- 4.225 μs )
```
Is there a way to avoid this pitfall?
An alternative would be to `deepseq` the data structure; this would require an `NFData` instance. The drawback would be repeated traversals of a data structure, part of which is probably already strict (for example, in a `Tree`).
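A minimal sketch of the `deepseq` route, assuming the `deepseq` package's `Control.DeepSeq.force` and again using plain lists rather than streamly streams: the accumulator keeps its lazy tuple type, and each step's result is fully evaluated. For a pair of `Int`s the extra traversal is trivial, but for a large, partly strict structure such as a tree, `force` walks the whole structure on every step, which is the drawback mentioned above. The name `sumProductForced` is hypothetical.

```haskell
import Control.DeepSeq (force)
import Data.List (foldl', scanl')

-- Keep the lazy tuple accumulator but fully evaluate it at every step.
-- force x = x `deepseq` x, so each step traverses the entire accumulator:
-- cheap for a pair of Ints, proportional to its size for something like a Tree.
sumProductForced :: [Int] -> (Int, Int)
sumProductForced =
    foldl' (\(_, p) (s0, x) -> force (s0, p * x)) (0, 1)
  . drop 1                                   -- skip the scan's seed (0, 0)
  . scanl' (\(s, _) x -> (s + x, x)) (0, 0)  -- running sum paired with each element

main :: IO ()
main = print (sumProductForced [1 .. 10])  -- (55,3628800)
```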
I also thought of putting an `NFData` constraint on the fold accumulator, but I think we need compile-time enforcement, not runtime evaluation. It would help if we could determine at compile time that all leaves of a container are strict, and fail the compilation if they are not.
Fuzzy thoughts: can we put a constraint on the type and, when that constraint is present, introspect the type to make sure that all leaves either have a strictness annotation (bang) or are defined in a module that enables `StrictData`?
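One possible shape for this, sketched with GHC.Generics: the generic representation records each field's `DecidedStrictness`, so a closed type family can walk `Rep b` and raise a custom `TypeError` when a field is lazy. A field with a bang, or a field in a module with `StrictData` enabled (unless marked lazy with `~`), is reported as `DecidedStrict` or `DecidedUnpack` and passes. The names `AllStrict`, `StrictAcc` and `foldlStrictAcc` are hypothetical, not part of streamly, and the check only inspects the accumulator's own fields: it does not recurse into the types of those fields, so a lazy structure nested inside a strict field would still slip through.

```haskell
{-# LANGUAGE ConstraintKinds, DataKinds, DeriveGeneric, FlexibleContexts,
             TypeFamilies, TypeOperators, UndecidableInstances #-}
import Data.Kind (Constraint, Type)
import Data.List (foldl')
import GHC.Generics
import GHC.TypeLits (ErrorMessage (..), TypeError)

-- Walk the generic representation and reject any field whose decided
-- strictness is lazy. Field selectors (M1 S) stop the recursion, so the
-- types of the fields themselves are not inspected.
type family AllStrict (f :: Type -> Type) :: Constraint where
  AllStrict (M1 S ('MetaSel name su ss 'DecidedLazy) f) =
    TypeError ('Text "fold accumulator has a lazy field")
  AllStrict (M1 S meta f) = ()           -- DecidedStrict or DecidedUnpack
  AllStrict (M1 i meta f) = AllStrict f  -- datatype (D) and constructor (C) wrappers
  AllStrict (f :*: g) = (AllStrict f, AllStrict g)
  AllStrict (f :+: g) = (AllStrict f, AllStrict g)
  AllStrict U1 = ()
  AllStrict V1 = ()

-- Hypothetical constraint a strict fold could demand of its accumulator.
type StrictAcc b = (Generic b, AllStrict (Rep b))

-- Hypothetical strict fold that only accepts strict accumulators.
foldlStrictAcc :: StrictAcc b => (b -> a -> b) -> b -> [a] -> b
foldlStrictAcc = foldl'

data Pair a b = Pair !a !b deriving (Show, Eq, Generic)

sumProduct :: [Int] -> Pair Int Int
sumProduct = foldlStrictAcc (\(Pair s p) x -> Pair (s + x) (p * x)) (Pair 0 1)

-- Rejected at compile time, because the fields of (,) are lazy:
-- bad :: [Int] -> (Int, Int)
-- bad = foldlStrictAcc (\(s, p) x -> (s + x, p * x)) (0, 1)

main :: IO ()
main = print (sumProduct [1 .. 10])  -- Pair 55 3628800
```

Because the check happens during type checking, it costs nothing at runtime, unlike the `NFData`-based approach.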
The following code takes 99 μs to get the sum and product of a stream of 100000 numbers:
If we break it into a scan followed by a fold, it takes 9.9 ms, which is 100x slower:
Using the sum `s0` coming from the scan works well when the accumulator in the fold is not a tuple: e.g. if we calculate `(p * x + s0)` instead of `(s0, p * x)`, it takes only 99 μs. Even with a tuple accumulator it works fine as long as we do not use the incoming sum in the tuple, e.g. accumulating `(0, p * x)`. But as soon as we accumulate `(s0, p * x)` in the fold, performance drops by 100x. Need to examine the Core and investigate what's going on here.
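To compare the three accumulator variants in Core, a sketch along these lines could be compiled with `ghc -O2`. The `OPTIONS_GHC` pragma asks GHC to dump the simplifier output (`-ddump-simpl`) to a file (`-ddump-to-file`), and `-dsuppress-all` trims annotation noise from it. It substitutes base lists for streamly streams (an assumption), so its Core only approximates what the benchmark compiles to; the function names are hypothetical.

```haskell
{-# OPTIONS_GHC -ddump-simpl -dsuppress-all -ddump-to-file #-}
import Data.List (foldl', scanl')

-- Running sum paired with each element (the scan's seed (0, 0) is dropped).
runningSums :: [Int] -> [(Int, Int)]
runningSums = drop 1 . scanl' (\(s, _) x -> (s + x, x)) (0, 0)

-- Variant 1: non-tuple accumulator that uses the incoming sum, p * x + s0.
scalarAcc :: [Int] -> Int
scalarAcc = foldl' (\p (s0, x) -> p * x + s0) 1 . runningSums

-- Variant 2: tuple accumulator that ignores the incoming sum, (0, p * x).
tupleIgnoringSum :: [Int] -> (Int, Int)
tupleIgnoringSum = foldl' (\(_, p) (_, x) -> (0, p * x)) (0, 1) . runningSums

-- Variant 3: tuple accumulator that keeps the incoming sum, (s0, p * x).
tupleKeepingSum :: [Int] -> (Int, Int)
tupleKeepingSum = foldl' (\(_, p) (s0, x) -> (s0, p * x)) (0, 1) . runningSums

main :: IO ()
main = do
  print (scalarAcc [1, 2, 3])         -- 27
  print (tupleIgnoringSum [1, 2, 3])  -- (0,6)
  print (tupleKeepingSum [1, 2, 3])   -- (6,6)
```

In the dump, the thing to look for is whether the worker loop for each variant carries unboxed `Int#` values or passes boxed, possibly unevaluated, tuple components between iterations.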