typelead / eta

The Eta Programming Language, a dialect of Haskell on the JVM
https://eta-lang.org
BSD 3-Clause "New" or "Revised" License

TCO for CPS-style programs via fast exceptions #442

Closed rahulmutt closed 6 years ago

rahulmutt commented 7 years ago
{-# LANGUAGE DeriveGeneric #-}
module Main where

import Data.List(intercalate)
import Data.Maybe(fromJust)
import GHC.Generics(Generic)
import Data.Aeson(FromJSON, decode)

import qualified Data.ByteString.Lazy as B

data Record = Record { hello :: !Int } deriving (Generic, Show)

instance FromJSON Record

main :: IO ()
main = do
  let record i = "{\"hello\":" ++ show i ++ "}"
  let records = "[" ++ intercalate "," (map record [1..200]) ++ "]"
  print $ readRecords $ B.pack $ map (toEnum . fromEnum) records
  where
  readRecords :: B.ByteString -> [Record]
  readRecords bs = fromJust $ decode bs

Investigate what causes such excessive stack usage for such a small input size.

rahulmutt commented 7 years ago

While the stack trace is too complicated to analyse, I do have a guess as to what may be happening as well as a way to fix it.

Stack growth is controlled as long as each evaluation of a thunk (marked by a call to Closure.evaluate on the stack) finishes relatively quickly, without evaluating too many more thunks in a nested fashion before getting the result value. There are two cases (there may be more) where this fails:

Given these observations, here's a lightweight solution that we can implement asap:

  1. Add a new int applyCalls field to StgContext.
  2. Change evaluate to reset context.applyCalls to 0 upon entry.
  3. In all the apply* functions defined for Thunk, Function, and PAP, increment context.applyCalls upon entry. Just before entering the next function, check whether context.applyCalls > X, where X is a number we will have to tune over time (probably around 1000), and if so:
    • Reset context.applyCalls
    • Throw a "fast exception" (an exception with no stack trace; let's call it CPSException), passing in the function closure, to unwind the stack to the last thunk evaluation (effectively, the nearest update frame).
    • Grab the closure from the exception and enter() it. The arguments should've already been loaded by apply and the context object is the same throughout an entire call stack for a given thread (invariant), so we can just go ahead and enter.
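The three steps above can be modelled in a few dozen lines of plain Java. This is a hedged sketch, not Eta's runtime code: `Closure`, `StgContext`, `applyCalls` and `CPSException` follow the names used in the proposal, while `MiniRuntime`, `SumStep`, the `MAX_APPLY_CALLS` value of 1000 and the simplification that each closure carries its own arguments are assumptions made for illustration. The "fast exception" is obtained with `RuntimeException`'s four-argument constructor, which lets the subclass switch off stack-trace capture.

```java
// Hypothetical, simplified stand-ins for the runtime pieces named in the proposal.
interface Closure {
    // In this sketch a closure carries its own arguments, standing in for
    // arguments that apply* would already have loaded into the context.
    Object enter(StgContext context);
}

final class StgContext {
    static final int MAX_APPLY_CALLS = 1000; // "X": assumed value, to be tuned
    int applyCalls;                          // step 1: the new counter field
}

// "Fast exception": suppression and stack-trace capture are disabled,
// so throwing it avoids the cost of filling in a stack trace.
final class CPSException extends RuntimeException {
    final Closure next; // the closure whose call was postponed

    CPSException(Closure next) {
        super(null, null, false, false);
        this.next = next;
    }
}

final class MiniRuntime {
    // Step 2: evaluating a thunk resets the counter and is the point the stack unwinds to.
    static Object evaluate(StgContext context, Closure thunk) {
        Closure current = thunk;
        while (true) {
            context.applyCalls = 0;
            try {
                return current.enter(context);
            } catch (CPSException e) {
                current = e.next; // resume the pending tail call on a shallow stack
            }
        }
    }

    // Step 3: every apply increments the counter and, past the limit,
    // unwinds to the nearest evaluate instead of growing the stack further.
    static Object apply(StgContext context, Closure function) {
        if (++context.applyCalls > StgContext.MAX_APPLY_CALLS) {
            context.applyCalls = 0;
            throw new CPSException(function);
        }
        return function.enter(context);
    }
}

// A CPS-style loop summing n, n-1, ..., 1: each step tail-calls the next through apply.
final class SumStep implements Closure {
    private final long n;
    private final long acc;

    SumStep(long n, long acc) {
        this.n = n;
        this.acc = acc;
    }

    public Object enter(StgContext context) {
        if (n == 0) {
            return acc;
        }
        return MiniRuntime.apply(context, new SumStep(n - 1, acc + n));
    }
}
```

With this model, `MiniRuntime.evaluate(new StgContext(), new SumStep(1_000_000L, 0L))` never holds more than about a thousand `apply`/`enter` frame pairs at once. Note that unwinding is only correct when every frame between `evaluate` and the throwing `apply` is a tail call, as in `SumStep`, whose `enter` returns whatever `apply` returns; a frame that still had work to do after `apply` returned would silently lose it, which is presumably one reason the runtime's non-tail operations need careful handling.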

The only potential downside of implementing this is that calling apply will be slightly slower than before. However, only one of the newly introduced branches (the one that does not unwind) is the common case, since the stack is rewound only periodically, so the JIT should optimise for it.

Given the static TCO we do at compile-time + this lightweight dynamic TCO at runtime, it should cover nearly all the cases pretty well and the only cause for a StackOverflowError should be a space leak which can be fixed at the application level with bang patterns.

tomshackell commented 7 years ago

Just a note that there is a repeating pattern in the stack trace for this issue; it's just that the repeating block is rather large.

The stack trace starts with:

Exception in thread "main" java.lang.StackOverflowError
    at aeson.data.aeson.parser.Internal$sat_s48ZZ7.thunkEnter(Unknown Source)
    at eta.runtime.thunk.UpdatableThunk.enter(UpdatableThunk.java:18)
    at eta.runtime.stg.Closure.evaluate(Closure.java:24)
    at aeson.data.aeson.parser.Internal$sat_s48ZZC.enter(Unknown Source)
    at eta.runtime.apply.Function.applyV(Function.java:16)
    at eta.runtime.exception.Exception.catch_(Exception.java:135)
    at aeson.data.aeson.parser.Internal$sat_s48ZZD.enter(Unknown Source)
    at eta.runtime.apply.Function.applyV(Function.java:16)
    at base.ghc.IO$unsafeDupablePerformIO.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$lvl29_s48YZZ.thunkEnter(Unknown Source)
    at eta.runtime.thunk.UpdatableThunk.enter(UpdatableThunk.java:18)
    at eta.runtime.stg.Closure.evaluate(Closure.java:24)

and then has this pattern

    at aeson.data.aeson.parser.Internal$$wsucc_s48ZZE.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$lvl28_s48YL.enter(Unknown Source)
    at eta.runtime.apply.Function.applyPPPP(Function.java:255)
    at aeson.data.aeson.parser.Internal$$Lr48NJ$s$wa1.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$$wa4.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$$Lr48P5$wa23.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$$wsucc_s4ASS.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$w3_s4ATY.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$$Lr48P0$wa18.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$$Lr48P1$wa19.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$w3_s4BC8.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$$wa.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$$Lr48P2$wa20.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$lvl29_s4AY3.enter(Unknown Source)
    at eta.runtime.apply.Function.applyPPPP(Function.java:255)
    at aeson.data.aeson.parser.Internal$$Lr48OW$wa17.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$w4_s4B1M.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$sat_s4B41.enter(Unknown Source)
    at eta.runtime.apply.Function.applyPPPP(Function.java:255)
    at aeson.data.aeson.parser.Internal$succ_s4AS3.enter(Unknown Source)
    at eta.runtime.apply.Function.applyPPPP(Function.java:255)
    at aeson.data.aeson.parser.Internal$w4_s4BHV.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$lvl28_s4BFI.enter(Unknown Source)
    at eta.runtime.apply.Function.applyPPPP(Function.java:255)
    at aeson.data.aeson.parser.Internal$lose_s49UM.enter(Unknown Source)
    at eta.runtime.apply.Function.applyPPPPP(Function.java:296)
    at attoparsec.data.attoparsec.internal.Types$sat_s1Y6D.enter(Unknown Source)
    at eta.runtime.apply.Function.applyPPPPP(Function.java:296)
    at aeson.data.aeson.parser.Internal$$wa40_s49U7.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$$wsucc2_s49ZZF.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$sat_s4A17.enter(Unknown Source)
    at eta.runtime.apply.Function.applyPPPP(Function.java:255)
    at aeson.data.aeson.parser.Internal$$wsucc_s497Y.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$succ_s498T.enter(Unknown Source)
    at eta.runtime.apply.Function.applyPPPP(Function.java:255)
    at aeson.data.aeson.parser.Internal$$Lr48NP$wa5.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$$Lr48NV$wa7.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$$s$wa27_s49TX.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$$wsucc_s49TL.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$$Lr48OI$wa15.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$w3_s4BC8.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$$wa.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$w4_s4BLG.enter(Unknown Source)
    at aeson.data.aeson.parser.Internal$sat_s4BNT.enter(Unknown Source)
    at eta.runtime.apply.Function.applyPPPP(Function.java:255)

repeated many times. That said I think it quite likely that the Parser monad is indeed written in a CPS style (many parser monads are).

rahulmutt commented 7 years ago

Thanks for taking a closer look! It may be repeating since it's going through a CPS loop - specifically the loop that will read each JSON array element one-by-one, but nonetheless, it's not something we can optimise away at compile-time (you can tell because of the apply calls which are only used for unknown function calls).
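To make the stack growth concrete, here is a plain-Java sketch (not Eta-generated code; `CpsArrayLoop` and `Success` are illustrative names) of a CPS loop over array elements in which "parsing" each element ends by calling a success continuation through an interface, i.e. an unknown call. The JVM does not eliminate tail calls, so every element leaves frames on the stack that are only popped once the last element has been handled, and the stack depth grows linearly with the array length.

```java
// Plain-Java sketch of a CPS loop over array elements, loosely mirroring
// the success-continuation pattern visible in the stack trace above.
final class CpsArrayLoop {
    interface Success {
        long onSuccess(int index, long parsed);
    }

    private final int length;
    // The success continuation re-enters parseElement for the next element.
    private final Success succ = this::parseElement;

    private CpsArrayLoop(int length) {
        this.length = length;
    }

    private long parseElement(int index, long parsed) {
        if (index == length) {
            return parsed; // end of array: only now does the result flow back up
        }
        // Call through the interface; the current frame stays on the stack until this returns.
        return succ.onSuccess(index + 1, parsed + 1);
    }

    static long run(int length) {
        return new CpsArrayLoop(length).parseElement(0, 0);
    }
}
```

For example, `CpsArrayLoop.run(1_000)` returns 1000, while a run over tens of millions of elements ends in a `StackOverflowError` on a normally sized thread stack. Each element here costs only a couple of frames; in the aeson trace above, each repetition costs several dozen.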

The proposed solution should be able to handle the case above. I spent a little more time thinking about this change and it's a bit more pervasive throughout the runtime (the runtime's primitive operations need to be dealt with carefully, for example), but it's still not too difficult to implement. What I'll do is beef up the test/perf suite a little before taking a stab at this, since it may introduce a lot of new bugs if we're not careful.

By the way, would you be interested in implementing this?

rahulmutt commented 7 years ago

@alexknvl shared a nice paper that captures exactly this idea: http://ac.els-cdn.com/S1571066105804591/1-s2.0-S1571066105804591-main.pdf?_tid=d38a1ba4-9441-11e7-8477-00000aacb35e&acdnat=1504839790_550a9b071a67678f8a3b085584ffbc9e

ChinaXing commented 7 years ago

How does Scala solve this? Does the JVM team have plans to support TCO?

puffnfresh commented 7 years ago

@ChinaXing Scala does not solve this.

ChinaXing commented 7 years ago

The JVM plan includes TCO:

Looking ahead to Java 9, areas of development would be: A self-tuning JVM, Massive Multicore scalability, Hypervisor integration, Improved native integration, Metaobject protocol, big-data support, Reification, Adding tail calls and continuations, A new meta-object protocol to improve cross language support, multi-tenancy, RAS, Resource management for cloud applications, and Heterogeneous compute models.

Read more: http://geeknizer.com/java-8-java-9-features-preview/#ixzz4suSfKHbW

@puffnfresh No, Scala should have solved this.

rahulmutt commented 6 years ago

In the origin/trampoline branch, a new function in Data.Function has been implemented:

trampoline :: a -> a

Effectively, it takes the computation passed as its argument and turns on tail-call optimization for it. The tail-call optimization has a minor overhead:

This implementation is highly experimental, but it is opt-in (you must use the trampoline function to activate this optimization at runtime).
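The opt-in behaviour can be illustrated with a similar model. This is a hypothetical plain-Java sketch, not taken from the origin/trampoline branch: `Step`, `TrampolineContext`, `TailCallException`, `Trampolines`, `CountDown` and the limit of 1000 are invented names and values. Here `trampoline` arms a per-context flag, acts as the unwinding point, and restores the caller's mode on exit, while `apply` outside a trampoline is an ordinary call guarded by a single boolean test. It assumes that everything run under `trampoline` is a chain of tail calls.

```java
// Hypothetical model of opt-in dynamic TCO: unwinding is armed only inside trampoline(...).
interface Step {
    Object enter(TrampolineContext context);
}

final class TrampolineContext {
    static final int LIMIT = 1000; // assumed unwinding threshold
    boolean trampolineActive;      // set only while a trampoline call is running
    int applyCalls;
}

// Cheap exception: no stack trace is captured.
final class TailCallException extends RuntimeException {
    final Step next;

    TailCallException(Step next) {
        super(null, null, false, false);
        this.next = next;
    }
}

final class Trampolines {
    // Analogue of `trampoline :: a -> a`: run a computation with dynamic TCO switched on.
    static Object trampoline(TrampolineContext context, Step computation) {
        boolean wasActive = context.trampolineActive;
        context.trampolineActive = true;
        try {
            Step current = computation;
            while (true) {
                context.applyCalls = 0;
                try {
                    return current.enter(context);
                } catch (TailCallException e) {
                    current = e.next; // continue the chain from a shallow stack
                }
            }
        } finally {
            context.trampolineActive = wasActive; // restore the caller's mode
        }
    }

    static Object apply(TrampolineContext context, Step function) {
        // Outside a trampoline the counter is never touched: code that doesn't opt in
        // pays only for the boolean test.
        if (context.trampolineActive && ++context.applyCalls > TrampolineContext.LIMIT) {
            context.applyCalls = 0;
            throw new TailCallException(function);
        }
        return function.enter(context);
    }
}

// A CPS-style countdown in which every step tail-calls the next through apply.
final class CountDown implements Step {
    private final long remaining;

    CountDown(long remaining) {
        this.remaining = remaining;
    }

    public Object enter(TrampolineContext context) {
        if (remaining == 0) {
            return "done";
        }
        return Trampolines.apply(context, new CountDown(remaining - 1));
    }
}
```

In this model, `Trampolines.trampoline(ctx, new CountDown(1_000_000L))` completes on a bounded stack, while the same chain started with a plain `apply` would grow the stack by one `apply`/`enter` pair per step. Wrapping `runParser (m v) [] (const Error) Success` in `trampoline $ ...`, as in the aeson patch below, plays the role of the `Trampolines.trampoline(...)` call here.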

This function allows general CPS-style programs to run without blowing the stack. As a test case, the program mentioned at the start of this issue was used and the following changes had to be made to attoparsec:

@@ -41,6 +41,7 @@ module Data.Attoparsec.ByteString.Lazy

 import Control.DeepSeq (NFData(rnf))
 import Data.ByteString.Lazy.Internal (ByteString(..), chunk)
+import Data.Function (trampoline)
 import Data.List (intercalate)
 import qualified Data.ByteString as B
 import qualified Data.Attoparsec.ByteString as A
@@ -86,13 +87,13 @@ instance Functor Result where
 -- | Run a parser and return its result.
 parse :: A.Parser a -> ByteString -> Result a
 parse p s = case s of
-              Chunk x xs -> go (A.parse p x) xs
-              empty      -> go (A.parse p B.empty) empty
+              Chunk x xs -> go (trampoline $ A.parse p x) xs
+              empty      -> go (trampoline $ A.parse p B.empty) empty
   where
     go (T.Fail x stk msg) ys      = Fail (chunk x ys) stk msg
     go (T.Done x r) ys            = Done (chunk x ys) r
-    go (T.Partial k) (Chunk y ys) = go (k y) ys
-    go (T.Partial k) empty        = go (k B.empty) empty
+    go (T.Partial k) (Chunk y ys) = go (trampoline $ k y) ys
+    go (T.Partial k) empty        = go (trampoline $ k B.empty) empty

 -- | Run a parser and print its result to standard output.
 parseTest :: (Show a) => A.Parser a -> ByteString -> IO ()

and aeson:

@@ -82,6 +82,7 @@ import Control.Applicative (Alternative(..))
 import Control.DeepSeq (NFData(..))
 import Control.Monad (MonadPlus(..), ap)
 import Data.Char (isLower, isUpper, toLower, isAlpha, isAlphaNum)
+import Data.Function (trampoline)
 import Data.Data (Data)
 import Data.Foldable (foldl')
 import Data.HashMap.Strict (HashMap)
@@ -423,23 +424,23 @@ emptyObject = Object H.empty

 -- | Run a 'Parser'.
 parse :: (a -> Parser b) -> a -> Result b
-parse m v = runParser (m v) [] (const Error) Success
+parse m v = trampoline $ runParser (m v) [] (const Error) Success
 {-# INLINE parse #-}

 -- | Run a 'Parser'.
 iparse :: (a -> Parser b) -> a -> IResult b
-iparse m v = runParser (m v) [] IError ISuccess
+iparse m v = trampoline $ runParser (m v) [] IError ISuccess
 {-# INLINE iparse #-}

 -- | Run a 'Parser' with a 'Maybe' result type.
 parseMaybe :: (a -> Parser b) -> a -> Maybe b
-parseMaybe m v = runParser (m v) [] (\_ _ -> Nothing) Just
+parseMaybe m v = trampoline $ runParser (m v) [] (\_ _ -> Nothing) Just
 {-# INLINE parseMaybe #-}

 -- | Run a 'Parser' with an 'Either' result type.  If the parse fails,
 -- the 'Left' payload will contain an error message.
 parseEither :: (a -> Parser b) -> a -> Either String b
-parseEither m v = runParser (m v) [] onError Right
+parseEither m v = trampoline $ runParser (m v) [] onError Right
   where onError path msg = Left (formatError path msg)
 {-# INLINE parseEither #-}

As you can see, the changes are minor: put a call to trampoline in front of any computation that runs a CPS computation which eventually returns a value.

All these patches will be merged to master and appropriate conditional directives will be used to keep these implementations backwards compatible.

agocorona commented 6 years ago

So a continuation monad can auto-trampoline itself in the monad instance. Is that ok?

m >>= k = Cont $ \c -> runCont m $ \a -> trampoline $ runCont (k a) c

rahulmutt commented 6 years ago

@agocorona That would set up a lot of trampolines because it's in the bind. The best practice is to use trampoline in the run function for the monad. So if you had a runCont :: Cont r a -> a, then you'd use trampoline in only that function. You should call trampoline in the top-level functions that run your monad, of which there are usually just a handful.

rahulmutt commented 6 years ago

Closing. This has been merged to master and seems to be working in a lot of cases.