rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

Bugfix: Poisson logpdf correctness at x=0 #104

Closed mkretsch327 closed 3 years ago

mkretsch327 commented 3 years ago

I noticed incorrect behavior in the Poisson distribution: dist.Poisson(1).logpdf(0.) evaluated to -jnp.inf rather than -1.

This PR updates the support constraint that caused this edge case. The logpmf function of the Poisson distribution in scipy.stats already covers the expected behavior outside the domain of both x and lmbda.
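For reference, here is a minimal sketch of the expected values in plain Python. It is not the mcx implementation: `poisson_logpmf` is a hypothetical helper, and returning NaN for a non-positive rate is an assumption made for illustration. The point is that the support is the non-negative integers, so x = 0 must pass the check.

```python
import math


def poisson_logpmf(x, lmbda):
    """Log-PMF of Poisson(lmbda) at x (hypothetical helper, not mcx code)."""
    # Assumption for this sketch: a non-positive rate is invalid, so return NaN.
    if lmbda <= 0:
        return math.nan
    # The support is {0, 1, 2, ...}. Using `x < 0` (not `x <= 0`) keeps
    # x = 0 inside the support, which is the edge case discussed above.
    if x < 0 or x != math.floor(x):
        return -math.inf
    # log p(x) = x * log(lmbda) - lmbda - log(x!)
    return x * math.log(lmbda) - lmbda - math.lgamma(x + 1)


print(poisson_logpmf(0.0, 1.0))   # inside the support
print(poisson_logpmf(-1.0, 1.0))  # outside the support
```

With lmbda = 1 and x = 0 the formula reduces to 0 * log(1) - 1 - log(0!) = -1, which is the value the original report expected.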

I've added a couple of new unit tests.

rlouf commented 3 years ago

Thank you for taking the time to submit a PR!

You are right that in this case the decorator is overkill, yet checking that jax.scipy implements the right behavior in the CI is also important; so thanks for keeping the test :)

Thank you for adding a new constraint. Besides checking the support, constraints will let us transform the logpdf automatically and choose the most appropriate sampler.

This looks good to me, merging. Thank you!