stripe / rainier

Bayesian inference in Scala.
https://rainier.fit
Apache License 2.0

Crash in a map over a vector of RVs when multiplying by an RV #499

Closed maciekrt closed 4 years ago

maciekrt commented 4 years ago

When trying to execute the following model:

import com.stripe.rainier.core._
import com.stripe.rainier.compute._

// The notebook defines `data` earlier as a sequence of (onset, cum_pdelay) pairs.
val (model, rts) = {
    val (onset, cum_pdelay) = data.unzip
    val serialInterval = Gamma(6, 1 / 1.5).latent
    val sigma = Normal(0, 0.03).latent.abs
    val Theta0 = Normal(0.1, 0.1).latent
    val diffs = Laplace(0, sigma).latentVec(data.size) // Maybe normal would be better
    val ThetaT = Vec.from(diffs.toList.scanLeft(Theta0)(_ + _))
    /*
    If we remove the multiplication by the random variable serialInterval,
    the crash disappears.
    */
    val Rts = ThetaT.map { m =>
        m * serialInterval + 1
    }
    val inferredYesterday = data.take(data.size - 1).map { case (o_t, cp_t) =>
        o_t / cp_t
    }.zipWithIndex
    val expectedToday = Vec.from(inferredYesterday.zip(cum_pdelay.drop(1))).map {
        case ((infYest, i), cpd) =>
            val lambda = infYest * cpd * ThetaT(i).exp
            Poisson(lambda)
    }
    (Model.observe(onset.drop(1), expectedToday), Rts)
}

I got the stack trace below after sampling finished; it is thrown from Trace.predict, while Rainier compiles the prediction. If we replace the multiplication by a random variable with any simpler operation, everything works fine. A notebook that fully reproduces the issue is here: https://github.com/maciekrt/covid-19/blob/rainierError/Rainier.ipynb

CC: @gkossakowski

java.util.NoSuchElementException: key not found: com.stripe.rainier.ir.Param@5e05efff
  scala.collection.MapLike.default(MapLike.scala:235)
  scala.collection.MapLike.default$(MapLike.scala:234)
  scala.collection.AbstractMap.default(Map.scala:63)
  scala.collection.MapLike.apply(MapLike.scala:144)
  scala.collection.MapLike.apply$(MapLike.scala:143)
  scala.collection.AbstractMap.apply(Map.scala:63)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:24)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:55)
  com.stripe.rainier.ir.ExprMethodGenerator.$anonfun$traverse$1(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar(MethodGenerator.scala:45)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar$(MethodGenerator.scala:42)
  com.stripe.rainier.ir.ExprMethodGenerator.storeGlobalVar(ExprMethodGenerator.scala:3)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:52)
  com.stripe.rainier.ir.ExprMethodGenerator.$anonfun$traverse$1(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar(MethodGenerator.scala:45)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar$(MethodGenerator.scala:42)
  com.stripe.rainier.ir.ExprMethodGenerator.storeGlobalVar(ExprMethodGenerator.scala:3)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:51)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:28)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:51)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:28)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:52)
  com.stripe.rainier.ir.ExprMethodGenerator.$anonfun$traverse$1(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar(MethodGenerator.scala:45)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar$(MethodGenerator.scala:42)
  com.stripe.rainier.ir.ExprMethodGenerator.storeGlobalVar(ExprMethodGenerator.scala:3)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:51)
  com.stripe.rainier.ir.ExprMethodGenerator.$anonfun$traverse$1(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar(MethodGenerator.scala:45)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar$(MethodGenerator.scala:42)
  com.stripe.rainier.ir.ExprMethodGenerator.storeGlobalVar(ExprMethodGenerator.scala:3)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:51)
  com.stripe.rainier.ir.ExprMethodGenerator.$anonfun$traverse$1(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar(MethodGenerator.scala:45)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar$(MethodGenerator.scala:42)
  com.stripe.rainier.ir.ExprMethodGenerator.storeGlobalVar(ExprMethodGenerator.scala:3)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:51)
  com.stripe.rainier.ir.ExprMethodGenerator.$anonfun$traverse$1(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar(MethodGenerator.scala:45)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar$(MethodGenerator.scala:42)
  com.stripe.rainier.ir.ExprMethodGenerator.storeGlobalVar(ExprMethodGenerator.scala:3)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:51)
  com.stripe.rainier.ir.ExprMethodGenerator.$anonfun$traverse$1(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar(MethodGenerator.scala:45)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar$(MethodGenerator.scala:42)
  com.stripe.rainier.ir.ExprMethodGenerator.storeGlobalVar(ExprMethodGenerator.scala:3)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:51)
  com.stripe.rainier.ir.ExprMethodGenerator.$anonfun$traverse$1(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar(MethodGenerator.scala:45)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar$(MethodGenerator.scala:42)
  com.stripe.rainier.ir.ExprMethodGenerator.storeGlobalVar(ExprMethodGenerator.scala:3)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:51)
  com.stripe.rainier.ir.ExprMethodGenerator.$anonfun$traverse$1(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar(MethodGenerator.scala:45)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar$(MethodGenerator.scala:42)
  com.stripe.rainier.ir.ExprMethodGenerator.storeGlobalVar(ExprMethodGenerator.scala:3)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:51)
  com.stripe.rainier.ir.ExprMethodGenerator.$anonfun$traverse$1(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar(MethodGenerator.scala:45)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar$(MethodGenerator.scala:42)
  com.stripe.rainier.ir.ExprMethodGenerator.storeGlobalVar(ExprMethodGenerator.scala:3)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:51)
  com.stripe.rainier.ir.ExprMethodGenerator.$anonfun$traverse$1(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar(MethodGenerator.scala:45)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar$(MethodGenerator.scala:42)
  com.stripe.rainier.ir.ExprMethodGenerator.storeGlobalVar(ExprMethodGenerator.scala:3)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:51)
  com.stripe.rainier.ir.ExprMethodGenerator.$anonfun$traverse$1(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar(MethodGenerator.scala:45)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar$(MethodGenerator.scala:42)
  com.stripe.rainier.ir.ExprMethodGenerator.storeGlobalVar(ExprMethodGenerator.scala:3)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:51)
  com.stripe.rainier.ir.ExprMethodGenerator.$anonfun$traverse$1(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar(MethodGenerator.scala:45)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar$(MethodGenerator.scala:42)
  com.stripe.rainier.ir.ExprMethodGenerator.storeGlobalVar(ExprMethodGenerator.scala:3)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:51)
  com.stripe.rainier.ir.ExprMethodGenerator.$anonfun$traverse$1(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar(MethodGenerator.scala:45)
  com.stripe.rainier.ir.MethodGenerator.storeGlobalVar$(MethodGenerator.scala:42)
  com.stripe.rainier.ir.ExprMethodGenerator.storeGlobalVar(ExprMethodGenerator.scala:3)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:34)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:51)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:28)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:51)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:28)
  com.stripe.rainier.ir.ExprMethodGenerator.traverse(ExprMethodGenerator.scala:55)
  com.stripe.rainier.ir.ExprMethodGenerator.<init>(ExprMethodGenerator.scala:16)
  com.stripe.rainier.ir.CompiledFunction$.$anonfun$apply$4(CompiledFunction.scala:79)
  scala.collection.immutable.List.map(List.scala:286)
  com.stripe.rainier.ir.CompiledFunction$.$anonfun$apply$3(CompiledFunction.scala:74)
  scala.collection.TraversableLike.$anonfun$flatMap$1(TraversableLike.scala:245)
  scala.collection.immutable.List.foreach(List.scala:392)
  scala.collection.TraversableLike.flatMap(TraversableLike.scala:245)
  scala.collection.TraversableLike.flatMap$(TraversableLike.scala:242)
  scala.collection.immutable.List.flatMap(List.scala:355)
  com.stripe.rainier.ir.CompiledFunction$.apply(CompiledFunction.scala:72)
  com.stripe.rainier.compute.Compiler.compile(Compiler.scala:29)
  com.stripe.rainier.core.Generator.prepare(Generator.scala:76)
  com.stripe.rainier.core.Generator.prepare$(Generator.scala:59)
  com.stripe.rainier.core.Generator$From.prepare(Generator.scala:110)
  com.stripe.rainier.core.Trace.predict(Trace.scala:35)
  ammonite.$sess.cmd11$Helper.predictRtOnset(cmd11.sc:37)
  ammonite.$sess.cmd15$Helper.$anonfun$res15$1(cmd15.sc:7)
  ammonite.$sess.cmd4$Helper.time(cmd4.sc:3)
  ammonite.$sess.cmd15$Helper.<init>(cmd15.sc:6)
  ammonite.$sess.cmd15$.<init>(cmd15.sc:7)
  ammonite.$sess.cmd15$.<clinit>(cmd15.sc:-1)
avibryant commented 4 years ago

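There's a somewhat subtle issue here that later versions of the API will likely fix. Your Model has no reference to serialInterval anywhere, so serialInterval does not participate in inference. Rts does depend on serialInterval, though, so when Trace.predict compiles Rts it meets a parameter the model doesn't know about, which is the NoSuchElementException ("key not found") in your stack trace. That is also why simpler operations that don't involve serialInterval work. The easy fix is to track serialInterval explicitly, something like this: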

val obsModel = Model.observe(onset.drop(1), expectedToday)
val trackSerialInterval = Model.track(Set(serialInterval))
(obsModel.merge(trackSerialInterval), Rts)
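To make the mechanism concrete, here is a toy sketch in plain Scala. It is hypothetical and does not use Rainier's classes: ToyIR, Param, Model, merge and evaluate are all made up for illustration. Following Avi's explanation, it assumes only that a model's parameters are the ones reachable from its likelihood terms plus anything explicitly tracked, and that later compilation looks each parameter up by key.

```scala
// Hypothetical toy IR for illustration only; these are NOT Rainier's classes.
object ToyIR {
  sealed trait Expr
  // Parameters use identity-based equality, like an anonymous IR node.
  final class Param extends Expr
  final case class Const(value: Double) extends Expr
  final case class Add(a: Expr, b: Expr) extends Expr
  final case class Mul(a: Expr, b: Expr) extends Expr

  // Every Param reachable from an expression.
  def params(e: Expr): Set[Param] = e match {
    case p: Param  => Set(p)
    case Const(_)  => Set.empty
    case Add(a, b) => params(a) ++ params(b)
    case Mul(a, b) => params(a) ++ params(b)
  }

  // A toy model knows only the Params reachable from its likelihood terms,
  // plus any expressions it was explicitly asked to track.
  final case class Model(terms: List[Expr], tracked: Set[Expr] = Set.empty) {
    val parameters: List[Param] = (terms ++ tracked).flatMap(params).distinct
    def merge(other: Model): Model =
      Model(terms ++ other.terms, tracked ++ other.tracked)
  }

  // Evaluate an expression against a model's parameter vector. Each Param is
  // looked up by key, so a Param the model never saw makes Map.apply throw
  // NoSuchElementException ("key not found").
  def evaluate(model: Model, e: Expr, values: Array[Double]): Double = {
    val index: Map[Param, Int] = model.parameters.zipWithIndex.toMap
    def go(x: Expr): Double = x match {
      case p: Param  => values(index(p))
      case Const(v)  => v
      case Add(a, b) => go(a) + go(b)
      case Mul(a, b) => go(a) * go(b)
    }
    go(e)
  }
}
```

With theta in the likelihood and serialInterval appearing only in rt = Add(Mul(theta, serialInterval), Const(1.0)), evaluating rt against Model(List(likelihood)) throws, while Add(theta, Const(1.0)) evaluates fine. Merging in Model(Nil, Set(serialInterval)) adds the missing parameter, which is the toy analogue of the Model.track fix.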
gkossakowski commented 4 years ago

Thanks Avi! Your suggestion fixed the problem. The crash in codegen, with an anonymous IR node in the error message, made it look like a bug in Rainier rather than a user error. I wonder if IR nodes could carry a name tag that would make such problems easier to diagnose. Perhaps you could capture the variable name automatically, the way (I think) @lihaoyi does in Mill?

On Wed, 29 Apr 2020 at 22:42, Avi Bryant notifications@github.com wrote:

Closed #499 https://github.com/stripe/rainier/issues/499.

-- gkk

avibryant commented 4 years ago

@gkossakowski do you have any pointers to how that's done in Mill?

gkossakowski commented 4 years ago

I believe it's done with the sourcecode library, which comes with handy examples (https://github.com/lihaoyi/sourcecode#examples). It's implemented as a macro, but a very, very thin one over the compiler's API. It's one of the best examples of a Scala macro that I have no objections to.
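The name-tag idea can be sketched in plain Scala. Everything below (NamedIR, Param, lookup, and the message wording) is hypothetical and is not Rainier's code. With the sourcecode library, the name could presumably be captured implicitly through an (implicit name: sourcecode.Name) parameter, whose .value is the name of the enclosing definition; the sketch passes the name explicitly so it stays dependency-free.

```scala
// Hypothetical sketch, not Rainier's API: a parameter node that may carry a
// human-readable name, used only to produce better error messages.
object NamedIR {
  final class Param(val name: Option[String]) {
    // Anonymous params print like the original error (Param@<hash>);
    // named ones print their variable name instead.
    override def toString: String =
      name.fold(s"Param@${Integer.toHexString(hashCode)}")(n => s"Param($n)")
  }

  // Look up a parameter's slot, failing with a message that names the
  // variable (when known) and hints at the likely user error.
  def lookup(index: Map[Param, Int], p: Param): Int =
    index.getOrElse(
      p,
      throw new NoSuchElementException(
        s"$p is not a parameter of this model; is it used in the likelihood or tracked?"
      )
    )
}
```

For example, looking up new Param(Some("serialInterval")) in an index that lacks it fails with "Param(serialInterval) is not a parameter of this model; is it used in the likelihood or tracked?" rather than an opaque Param@5e05efff.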