stripe / rainier

Bayesian inference in Scala.
https://rainier.fit
Apache License 2.0

Make Generator's Monad instance stack safe #366

Closed sritchie closed 5 years ago

sritchie commented 5 years ago

This PR changes the implementation of Generator[T] to a sealed trait with two concrete instances - Constant[T] and From[T]. This allowed me to rewrite the tailRecM method of the monad instance to make it stack safe.
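
As a rough illustration of the idea (not the code in this PR), here is a minimal sketch. It assumes a much-reduced `Gen[T]` with only a `get` method, uses `scala.util.Random` in place of Rainier's `RNG`, and ignores requirements entirely; the names `GeneratorSketch`, `Gen`, and `countTo` are hypothetical. It shows how `tailRecM` can run the `Left`/`Right` loop inside a single `From` with a `@tailrec` helper, rather than nesting one `flatMap` per step.

```scala
import scala.annotation.tailrec
import scala.util.Random

object GeneratorSketch {
  // Hypothetical, simplified stand-in for Generator[T]; requirements are omitted.
  sealed trait Gen[T] {
    def get(rng: Random): T
  }

  // Always yields the same value.
  final case class Constant[T](value: T) extends Gen[T] {
    def get(rng: Random): T = value
  }

  // Yields a value computed from the random source.
  final case class From[T](f: Random => T) extends Gen[T] {
    def get(rng: Random): T = f(rng)
  }

  // Runs the step function in a loop, so the call stack does not grow
  // with the number of Left results.
  def tailRecM[A, B](a: A)(step: A => Gen[Either[A, B]]): Gen[B] =
    From { rng =>
      @tailrec
      def loop(current: A): B =
        step(current).get(rng) match {
          case Left(next)  => loop(next)
          case Right(done) => done
        }
      loop(a)
    }

  // Usage: count up to `limit`, one monadic step per increment.
  def countTo(limit: Int): Gen[Int] =
    tailRecM[Int, Int](0) { n =>
      if (n < limit) Constant[Either[Int, Int]](Left(n + 1))
      else Constant[Either[Int, Int]](Right(n))
    }
}
```

With this shape, something like `GeneratorSketch.countTo(1000000).get(new Random(0L))` completes without a `StackOverflowError`, which is the property the cats `.monad` law tests exercise.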

Thanks to cats, adjusting the tests was super easy! I just changed .stackUnsafeMonad to .monad (after verifying that it failed before my changes) and, 💥 , all was well.

I rewrote much of Generator.scala to use Constant[T] and From[T] instead of creating anonymous instances. I'll comment below on the changes.

avi-stripe commented 5 years ago

@sritchie so great to see you contributing!

sritchie commented 5 years ago

@avi-stripe, here's one I don't understand. My change is causing an error in MY code at this place:

https://github.com/sritchie/rainier/blob/sritchie/stack_safe_generator/rainier-core/src/main/scala/com/stripe/rainier/core/Categorical.scala#L38

Could this have something to do with the extra requirements we're now pulling through, at least in the Const case?

avi-stripe commented 5 years ago

@sritchie are we shadowing the scala builtin require method?

sritchie commented 5 years ago

@avi-stripe I don't think so.

[error] (run-main-8) java.lang.IllegalArgumentException: requirement failed
[error] java.lang.IllegalArgumentException: requirement failed
[error]     at scala.Predef$.require(Predef.scala:268)
[error]     at com.stripe.rainier.core.Categorical.$anonfun$generator$3(Categorical.scala:38)

Gotta run for now... I'll take a look later when I get a chance, with the specific generator instance I was using.

avi-stripe commented 5 years ago

@sritchie gotcha. so I'm sure you're right that it's related to extra requirements, and in particular, changing from the interpreter (which uses BigDecimal) to using compiled floating point ops. So we're losing precision somewhere and are no longer within that 1e-6 epsilon. It would be interesting to know how far off we are; can probably relax it to a larger eps.
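
To make the hypothesis concrete, here is a small sketch, unrelated to Rainier's actual code, of how a total computed in `Double` can miss 1.0 slightly while the same sum in `BigDecimal` is exact. `PrecisionSketch` and its members are hypothetical names; `eps` simply mirrors the 1e-6 tolerance mentioned above.

```scala
object PrecisionSketch {
  // Ten equal weights that should sum to exactly 1.
  val doubleTotal: Double = Seq.fill(10)(0.1).sum
  val decimalTotal: BigDecimal = Seq.fill(10)(BigDecimal("0.1")).sum

  // Tolerance mirroring the 1e-6 epsilon discussed in this thread.
  val eps: Double = 1e-6

  // A tolerance check in the style of a normalization requirement.
  def isNormalized(total: Double): Boolean = math.abs(total - 1.0) < eps
}
```

Here `doubleTotal` is not exactly 1.0 but still passes `isNormalized`, while `decimalTotal` equals 1 exactly. Rounding in a short sum like this stays far below 1e-6, so a much larger discrepancy would point to a different cause.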

sritchie commented 5 years ago

Something major is wrong in this particular case:

// Math.abs(n.toDouble(cdf.last._2) - 1.0) == 0.9
[error] (run-main-c) java.lang.IllegalArgumentException: requirement failed

I'll reproduce with a smaller example, maybe we can track down some other bug.

avi-stripe commented 5 years ago

@sritchie yikes. Looking forward to seeing the repro.

sritchie commented 5 years ago

Okay, @avi-stripe, I've got my repro. I was allowing an empty sequence into the key of a map I passed to Categorical.normalize. As you can see below, the error occurred when I then flatMapped to Categorical.list, to generate items from whichever sequence the generator returned by normalize selected.

The 0.9 error returned was the weight I'd associated with the empty sequence.
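
The lost mass can be reproduced with a toy model. The sketch below is not Rainier's `Categorical`; `CategoricalSketch`, `Cat`, and `uniform` are hypothetical, and it assumes that `flatMap` multiplies outer and inner weights and that a uniform distribution over an empty sequence has no outcomes at all.

```scala
object CategoricalSketch {
  // Hypothetical, simplified categorical distribution: outcome -> probability.
  final case class Cat[T](pmf: Map[T, Double]) {
    // Weight each inner outcome by the probability of the outer outcome that produced it.
    def flatMap[U](f: T => Cat[U]): Cat[U] = {
      val weighted: Seq[(U, Double)] = pmf.toSeq.flatMap {
        case (t, p) => f(t).pmf.toSeq.map { case (u, q) => (u, p * q) }
      }
      Cat(weighted.groupBy(_._1).map { case (u, ps) => u -> ps.map(_._2).sum })
    }

    def totalMass: Double = pmf.values.sum
  }

  // Uniform distribution over the items; an empty sequence yields an empty pmf.
  def uniform[T](items: Seq[T]): Cat[T] =
    Cat(items.groupBy(identity).map { case (t, ts) => t -> (ts.size.toDouble / items.size) })

  // Same shape as the repro: 0.1 on a real list, 0.9 on an empty one.
  val outer: Cat[Seq[Int]] = Cat(Map(Seq(1, 2, 3) -> 0.1, Seq.empty[Int] -> 0.9))
  val flattened: Cat[Int] = outer.flatMap[Int](items => uniform(items))

  // Distance of the flattened total from 1.0; the 0.9 placed on the empty list has vanished.
  val missingMass: Double = math.abs(flattened.totalMass - 1.0)
}
```

In this model `missingMass` comes out at roughly 0.9, matching the failed requirement. A guard such as `require(items.nonEmpty)` in the hypothetical `uniform` would surface the problem where the empty sequence enters, rather than at a later normalization check.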

So it turns out this error is unrelated to this PR or to requirements, and I can restructure my code to do a better job here.

I'll drop the test file here if you see some fix and want to take it... otherwise I'll call this PR complete.

Thanks!

package com.stripe.rainier.core

import com.stripe.rainier.compute.{Evaluator, Real}
import com.stripe.rainier.sampler.{RNG, ScalaRNG}
import org.scalatest.FunSuite

class CategoricalTest extends FunSuite {
  implicit val rng: RNG = ScalaRNG(1527608515939L)
  implicit val evaluator: Numeric[Real] = new Evaluator(Map.empty)

  test("Categorical.list removes empty items.") {
    val expected = Seq(1, 2, 3)
    val gen = Categorical
      .normalize(
        Map(
          expected -> 0.1,
          // The empty sequence carries the 0.9 weight that shows up in the failed requirement.
          Seq.empty[Int] -> 0.9
        )
      )
      .flatMap(Categorical.list(_))
      .generator

    val choice = gen.get
    assert(
      expected.contains(choice),
      s"The generator returned an unexpected value - $choice - not contained in $expected."
    )
  }
}

avibryant commented 5 years ago

@sritchie got it, yes, this seems unrelated but something we should protect better against. This PR is ready to merge IMO.

sritchie commented 5 years ago

Thanks, @avibryant! When you get a chance would you mind publishing a snapshot build?

avi-stripe commented 5 years ago

@sritchie I think we can do a quick 0.2.3 release later this week

sritchie commented 5 years ago

Thanks @avi-stripe, that sounds great. I need it to get https://github.com/sritchie/scala-rl building on Travis.