composewell / streamly

High performance, concurrent functional programming abstractions
https://streamly.composewell.com

Fusion issue with unfolds in multi-stream operations #1709

Open harendra-kumar opened 2 years ago

harendra-kumar commented 2 years ago

If we generate streams using unfolds and use them in multi-stream operations such as isSubsequenceOf, the operation does not fuse. However, if we use the direct implementations of the stream generators instead of unfolds, it fuses.

The difference with unfolds is that they have an additional initial state (UnfoldNothing), in which the seed is injected before the stream starts generating elements:

unfold :: Applicative m => Unfold m a b -> a -> Stream m b
unfold (Unfold ustep inject) seed = Stream step UnfoldNothing

    where

    {-# INLINE step #-}
    step _ UnfoldNothing = Skip . UnfoldJust <$> inject seed
    step _ (UnfoldJust st) = do
        -- ... (remainder of the definition elided)
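To make the extra state concrete, here is a small self-contained model. Step, Stream, Unfold, UnfoldState and unfold mirror the names in the snippet above, but the definitions are simplified assumptions rather than streamly's code (in particular, the step functions take no State argument). unfoldrMU, toList, firstStepIsSkip and example are hypothetical helpers. The model shows that an unfold-generated stream always starts with a Skip that only injects the seed.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

-- Simplified, hypothetical model of streamly's Stream and Unfold types.
-- Unlike the real library, step functions take no State argument.
import Data.Functor.Identity (Identity (..))

-- One step of a stream: emit an element, skip without emitting, or stop.
data Step s a = Yield a s | Skip s | Stop

-- A stream is a step function plus an initial state of a hidden type s.
data Stream m a = forall s. Stream (s -> m (Step s a)) s

-- An unfold is a step function plus an inject function that turns a seed
-- into the initial state.
data Unfold m a b = forall s. Unfold (s -> m (Step s b)) (a -> m s)

-- Wrapper state recording whether the seed has been injected yet.
data UnfoldState s = UnfoldNothing | UnfoldJust s

unfold :: Monad m => Unfold m a b -> a -> Stream m b
unfold (Unfold ustep inject) seed = Stream step UnfoldNothing
  where
    -- First step: inject the seed and emit Skip; no element is produced.
    step UnfoldNothing = Skip . UnfoldJust <$> inject seed
    -- Later steps: run the unfold's own step and rewrap its state.
    step (UnfoldJust st) = do
        r <- ustep st
        return $ case r of
            Yield x s -> Yield x (UnfoldJust s)
            Skip s -> Skip (UnfoldJust s)
            Stop -> Stop

-- Hypothetical helper: an unfold built from an unfoldr-style function.
unfoldrMU :: Monad m => (a -> m (Maybe (b, a))) -> Unfold m a b
unfoldrMU next = Unfold step return
  where
    step s = do
        r <- next s
        return $ case r of
            Just (x, s') -> Yield x s'
            Nothing -> Stop

-- Drain a stream into a list, ignoring Skips.
toList :: Monad m => Stream m a -> m [a]
toList (Stream step s0) = go s0
  where
    go s = do
        r <- step s
        case r of
            Yield x s' -> (x :) <$> go s'
            Skip s' -> go s'
            Stop -> return []

-- True if the very first step of the stream is a Skip.
firstStepIsSkip :: Monad m => Stream m a -> m Bool
firstStepIsSkip (Stream step s0) = do
    r <- step s0
    return $ case r of
        Skip _ -> True
        _ -> False

-- The numbers 1..5 generated through an unfold.
example :: Stream Identity Int
example = unfold (unfoldrMU next) 1
  where
    next n = return (if n > 5 then Nothing else Just (n, n + 1))

main :: IO ()
main = do
    print (runIdentity (toList example))
    print (runIdentity (firstStepIsSkip example))
```

Running it prints [1,2,3,4,5] followed by True: the elements are unaffected, but the stream spends its first step on the seed injection.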

Consider a stream generated by applying unfold to Unfold.unfoldrM, as in this test program:

module Main (main, stream) where

import Streamly.Internal.Data.Stream (Stream, unfold)
import qualified Streamly.Internal.Data.Stream as Stream
import qualified Streamly.Internal.Data.Fold as Fold
import qualified Streamly.Internal.Data.Unfold as Unfold

{-# INLINE sourceUnfoldrM #-}
sourceUnfoldrM :: Monad m => Int -> Int -> Stream m Int
sourceUnfoldrM count start = unfold (Unfold.unfoldrM step) start
    where
    step cnt =
        if cnt > start + count
        then return Nothing
        else return (Just (cnt, cnt + 1))

{-# INLINE stream #-}
stream :: Stream IO Int
stream = sourceUnfoldrM value 1
    where value = 100000

It simplifies to the following core:

Rec {
-- RHS size: {terms: 53, types: 59, coercions: 34, joins: 1/4}
$wgo_r4fj
  = \ w_s46l @ r_s46m w1_s46o w2_s46q w3_s46r ->
      join {
        exit_Xy s_a39P st_a352
          = case st_a352 of wild_a35t { I# x_a35u ->
            case ># x_a35u 100001# of {
              __DEFAULT ->
                ((w1_s46o
                    wild_a35t
                    (let { w4_s47e = +# x_a35u 1# } in
                     let { w5_s47f = I# w4_s47e } in
                     let { w6_X47j = UnfoldJust w5_s47f } in
                     (\ @ r1_X47o _ w8_X47s _ w10_X47w w11_X47y ->
                        $wgo_r4fj w6_X47j w8_X47s w10_X47w w11_X47y)
                     `cast` <Co:30>))
                 `cast` <Co:2>)
                  s_a39P;
              1# -> (w2_s46q `cast` <Co:2>) s_a39P
            }
            } } in
      case w_s46l of {
        UnfoldNothing -> jump exit_Xy w3_s46r start_r4fh;
        UnfoldJust st_a352 -> jump exit_Xy w3_s46r st_a352
      }
end Rec }

-- RHS size: {terms: 11, types: 21, coercions: 0, joins: 0/0}
stream1_r4fl
  = \ @ r_s46m _ w1_s46o _ w3_s46q w4_s46r ->
      $wgo_r4fj UnfoldNothing w1_s46o w3_s46q w4_s46r

The rest of the test program passes two copies of this stream to isSubsequenceOf:

{-# INLINE isSubsequenceOf #-}
isSubsequenceOf :: Monad m => Stream m Int -> m Bool
isSubsequenceOf src = Stream.isSubsequenceOf src src

main = do
    r <- isSubsequenceOf stream
    print r
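
The core for this program does not fuse; the recursive calls still pass freshly allocated UnfoldJust, I# and Just values:

-- RHS size: {terms: 12, types: 6, coercions: 0, joins: 0/0}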
main_$s$wgo1
  = \ sc_s40s sc1_s40r sc2_s40q ->
      $wgo_r435
        (UnfoldJust (I# sc2_s40q)) lvl1_r433 (Just (I# sc1_s40r)) sc_s40s

-- RHS size: {terms: 91, types: 32, coercions: 4, joins: 0/0}
$wgo_r435
  = \ ww_s3Xh ww1_s3Xi ww2_s3Xj w_s3Xe ->
      case ww2_s3Xj of wild3_a37n {
        Nothing ->
          case ww_s3Xh of {
            UnfoldNothing -> $wgo_r435 lvl1_r433 ww1_s3Xi Nothing w_s3Xe;
            UnfoldJust st_a352 ->
              case st_a352 of wild_a35t { I# x_a35u ->
              case ># x_a35u 100001# of {
                __DEFAULT ->
                  $wgo_r435
                    (UnfoldJust (I# (+# x_a35u 1#))) ww1_s3Xi (Just wild_a35t) w_s3Xe;
                1# -> ((hPutStr' stdout $fShowBool2 True) `cast` <Co:2>) w_s3Xe
              }
              }
          };
        Just x_a37x ->
          case ww1_s3Xi of {
            UnfoldNothing -> $wgo_r435 ww_s3Xh lvl1_r433 wild3_a37n w_s3Xe;
            UnfoldJust st_a352 ->
            ...

On the other hand, if we generate the stream with the direct stream unfoldrM operation instead of an unfold, the loop fuses and the core looks like this:

main_$s$wgo1
  = \ sc_s42D sc1_s42C sc2_s42B sc3_s42A ->
      case ># sc2_s42B 100001# of {
        __DEFAULT ->
          case ==# sc1_s42C sc2_s42B of {
            __DEFAULT ->
              main_$s$wgo1 sc_s42D sc1_s42C (+# sc2_s42B 1#) sc3_s42A;
            1# ->
              case ># sc3_s42A 100001# of {
                __DEFAULT ->
                  main_$s$wgo1 sc_s42D sc3_s42A (+# sc2_s42B 1#) (+# sc3_s42A 1#);
                1# -> ((hPutStr' stdout $fShowBool2 True) `cast` <Co:2>) sc_s42D
              }
          };
        1# -> ((hPutStr' stdout $fShowBool4 True) `cast` <Co:2>) sc_s42D
      }
end Rec }

We do have direct stream implementations of the generation operations, but we would prefer to generate everything using unfolds. That would be possible if GHC fused them properly. We need to investigate further what is going on here.
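A sketch of the difference between a direct generator and an unfold, in a simplified model. All names here are hypothetical, and Maybe stands in for streamly's UnfoldState. The direct generator uses the seed itself as its state, while the unfold-style one wraps the state so it can remember whether the seed has been injected, which costs one leading Skip.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

-- Hypothetical, simplified model (not streamly's code): step functions take
-- no State argument, and the generator names are made up for illustration.
import Data.Functor.Identity (Identity (..))

data Step s a = Yield a s | Skip s | Stop

data Stream m a = forall s. Stream (s -> m (Step s a)) s

-- Direct generator: the stream state is the seed itself, so the very
-- first step can already yield an element.
unfoldrMDirect :: Monad m => (s -> m (Maybe (a, s))) -> s -> Stream m a
unfoldrMDirect next seed = Stream step seed
  where
    step s = do
        r <- next s
        return $ case r of
            Just (x, s') -> Yield x s'
            Nothing -> Stop

-- Unfold-style generator with the unfold inlined: the state is wrapped in
-- Maybe to record whether the seed has been injected, which costs an
-- initial Skip.
unfoldrMViaUnfold :: Monad m => (s -> m (Maybe (a, s))) -> s -> Stream m a
unfoldrMViaUnfold next seed = Stream step Nothing
  where
    step Nothing = return (Skip (Just seed))
    step (Just s) = do
        r <- next s
        return $ case r of
            Just (x, s') -> Yield x (Just s')
            Nothing -> Stop

-- The sequence of steps a stream takes: 'Y' for Yield, 'S' for Skip.
shape :: Monad m => Stream m a -> m String
shape (Stream step s0) = go s0
  where
    go s = do
        r <- step s
        case r of
            Yield _ s' -> ('Y' :) <$> go s'
            Skip s' -> ('S' :) <$> go s'
            Stop -> return ""

-- Sample generator producing 1, 2, 3.
countTo3 :: Monad m => Int -> m (Maybe (Int, Int))
countTo3 n = return (if n > 3 then Nothing else Just (n, n + 1))

main :: IO ()
main = do
    putStrLn (runIdentity (shape (unfoldrMDirect countTo3 1)))
    putStrLn (runIdentity (shape (unfoldrMViaUnfold countTo3 1)))
```

With the sample generator, the direct version takes the steps YYY and the unfold-style version takes SYYY; both yield the same elements.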

harendra-kumar commented 2 years ago

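Operations affected by this (first column: Prelude.Serial time in μs; second column: how much slower the Data.Stream version is, in percent):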

Benchmark                                               Prelude.Serial(0)(μs) Data.Stream(1) - Prelude.Serial(0)(%)
------------------------------------------------------- --------------------- -------------------------------------
o-1-space.multi-stream.isSubsequenceOf                                 111.84                              +1284.86
o-1-space.multi-stream.stripPrefix                                     111.91                              +1277.46
o-1-space.multi-stream.isPrefixOf                                      111.83                              +1270.55

o-1-space.exceptions/serial.retryUnknown                              1195.69                               +193.27
o-1-space.exceptions/serial.retryNoneSimple                           1717.68                               +126.67
o-1-space.exceptions/serial.retryNone                                 1605.25                                +40.85

o-1-space.mapping.foldrS                                              3085.84                                +64.95
o-1-space.mapping.foldrSMap                                           3207.70                                +56.39
o-1-space.mapping.foldrT                                              3851.66                                +48.55
o-1-space.mapping.foldrTMap                                           3971.81                                +47.23
o-1-space.elimination.build.Identity.foldrMToListLength                810.75                                +29.70

o-1-space.elimination.uncons                                          1057.74                                +31.26
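As a worked reading of the first row, assuming the percentage column is computed as (Data.Stream time − Prelude.Serial time) / Prelude.Serial time × 100, the Data.Stream time for isSubsequenceOf comes out at roughly 13.8 times the baseline:

$$
t_{\text{Data.Stream}} = t_{\text{Prelude.Serial}}\left(1 + \frac{p}{100}\right) = 111.84\,\mu\text{s} \times \left(1 + \frac{1284.86}{100}\right) = 111.84\,\mu\text{s} \times 13.8486 \approx 1548.8\,\mu\text{s}
$$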
harendra-kumar commented 2 years ago

The problem is not really specific to unfolds; it is a general fusion issue. Unfolds just happen to emit a Skip at the beginning of the stream (while injecting the seed), and that triggers the issue. Using a drop operation at the start of a stream reproduces the same situation without unfolds. For example, the following code, which uses a stream instead of an unfold, runs into a similar fusion issue:

main = do
    r <- isSubsequenceOf (Stream.drop 1 stream)
    print r
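The leading Skip can be reproduced without unfolds in a simplified model. Here dropS, fromListS, toList and shape are hypothetical names, not streamly's API, and streamly's actual drop keeps slightly different state. Dropping an element is expressed as a Skip, so the dropped stream starts with a Skip just like an unfold-generated one:

```haskell
{-# LANGUAGE ExistentialQuantification #-}

-- Hypothetical, simplified model (not streamly's Stream.drop): step
-- functions take no State argument.
import Data.Functor.Identity (Identity (..))

data Step s a = Yield a s | Skip s | Stop

data Stream m a = forall s. Stream (s -> m (Step s a)) s

-- A direct generator from a list: it never skips.
fromListS :: Monad m => [a] -> Stream m a
fromListS = Stream step
  where
    step [] = return Stop
    step (x : xs) = return (Yield x xs)

-- Drop the first n elements. Each dropped element becomes a Skip, so the
-- resulting stream begins with Skip steps, like an unfold-generated one.
dropS :: Monad m => Int -> Stream m a -> Stream m a
dropS n (Stream step s0) = Stream step' (s0, n)
  where
    step' (s, i) = do
        r <- step s
        return $ case r of
            Yield x s'
                | i > 0 -> Skip (s', i - 1)
                | otherwise -> Yield x (s', i)
            Skip s' -> Skip (s', i)
            Stop -> Stop

-- Elements of a stream, ignoring Skips.
toList :: Monad m => Stream m a -> m [a]
toList (Stream step s0) = go s0
  where
    go s = do
        r <- step s
        case r of
            Yield x s' -> (x :) <$> go s'
            Skip s' -> go s'
            Stop -> return []

-- The sequence of steps a stream takes: 'Y' for Yield, 'S' for Skip.
shape :: Monad m => Stream m a -> m String
shape (Stream step s0) = go s0
  where
    go s = do
        r <- step s
        case r of
            Yield _ s' -> ('Y' :) <$> go s'
            Skip s' -> ('S' :) <$> go s'
            Stop -> return ""

main :: IO ()
main = do
    let xs = [1, 2, 3] :: [Int]
    print (runIdentity (toList (dropS 1 (fromListS xs))))
    putStrLn (runIdentity (shape (fromListS xs)))
    putStrLn (runIdentity (shape (dropS 1 (fromListS xs))))
```

For the list [1, 2, 3], the direct stream takes the steps YYY, while dropS 1 of it takes SYY and yields [2,3].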
harendra-kumar commented 2 years ago

Since these are fold operations implemented as recursive loops, passing a SPEC argument (which tells GHC's SpecConstr pass to specialise the loop) and using strict arguments seems to do the trick:

isSubsequenceOf :: (Eq a, Monad m) => Stream m a -> Stream m a -> m Bool
isSubsequenceOf (Stream stepa ta) (Stream stepb tb) = go SPEC Nothing' ta tb
  where
    go !_ Nothing' sa sb = do
        r <- stepa defState sa
        case r of
            Yield x sa' -> go SPEC (Just' x) sa' sb
            Skip sa' -> go SPEC Nothing' sa' sb
            Stop -> return True

    go !_ (Just' x) sa sb = do
        r <- stepb defState sb
        case r of
            Yield y sb' ->
                if x == y
                    then go SPEC Nothing' sa sb'
                    else go SPEC (Just' x) sa sb'
            Skip sb' -> go SPEC (Just' x) sa sb'
            Stop -> return False
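For experimenting outside the library, the same loop can be run against a small self-contained model. Step, Stream, Maybe' and fromListS below are simplified stand-ins (assumptions, not streamly's definitions; the step functions take no State argument), while SPEC is the real type exported by GHC.Exts. SpecConstr is enabled at -O2, so compiling with -O2 and -ddump-simpl is one way to check whether the loop gets specialised.

```haskell
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ExistentialQuantification #-}

-- Simplified, hypothetical model of the SPEC-based loop (not streamly's
-- code): the step functions here take no State argument.
import Data.Functor.Identity (Identity (..))
import GHC.Exts (SPEC (..))

data Step s a = Yield a s | Skip s | Stop

data Stream m a = forall s. Stream (s -> m (Step s a)) s

-- A strict Maybe, standing in for streamly's Maybe'.
data Maybe' a = Nothing' | Just' !a

-- A direct generator from a list.
fromListS :: Monad m => [a] -> Stream m a
fromListS = Stream step
  where
    step [] = return Stop
    step (x : xs) = return (Yield x xs)

isSubsequenceOf :: (Eq a, Monad m) => Stream m a -> Stream m a -> m Bool
isSubsequenceOf (Stream stepa ta) (Stream stepb tb) = go SPEC Nothing' ta tb
  where
    -- The SPEC argument asks SpecConstr to specialise go on the
    -- constructors of its other arguments; the bang pattern follows the
    -- usual idiom for SPEC arguments.
    -- No pending element: pull the next one from the first stream.
    go !_ Nothing' sa sb = do
        r <- stepa sa
        case r of
            Yield x sa' -> go SPEC (Just' x) sa' sb
            Skip sa' -> go SPEC Nothing' sa' sb
            Stop -> return True
    -- Pending element x: search for it in the second stream.
    go !_ (Just' x) sa sb = do
        r <- stepb sb
        case r of
            Yield y sb'
                | x == y -> go SPEC Nothing' sa sb'
                | otherwise -> go SPEC (Just' x) sa sb'
            Skip sb' -> go SPEC (Just' x) sa sb'
            Stop -> return False

main :: IO ()
main = do
    print (runIdentity (isSubsequenceOf (fromListS [1, 3 :: Int]) (fromListS [1, 2, 3])))
    print (runIdentity (isSubsequenceOf (fromListS [3, 1 :: Int]) (fromListS [1, 2, 3])))
```

The first call prints True and the second False, because once 3 is matched at the end of the second stream nothing is left to match 1.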

With this definition, the program compiles to the following core, a single loop over unboxed Int# values:

main_$sgo
  = \ sc_s3b3 sc1_s3b2 sc2_s3b1 eta_B0 ->
      case ># sc_s3b3 100000# of {
        __DEFAULT ->
          case ==# sc2_s3b1 sc_s3b3 of {
            __DEFAULT -> main_$sgo (+# sc_s3b3 1#) sc1_s3b2 sc2_s3b1 eta_B0;
            1# ->
              case ># sc1_s3b2 100000# of {
                __DEFAULT ->
                  main_$sgo (+# sc_s3b3 1#) (+# sc1_s3b2 1#) sc1_s3b2 eta_B0;
                1# -> hPutStr2 stdout $fShowBool4 True eta_B0
              }
          };
        1# -> hPutStr2 stdout $fShowBool5 True eta_B0
      }