rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

Add BetaBinomial Distribution #79

Closed: tblazina closed this 3 years ago

tblazina commented 3 years ago

Add the BetaBinomial distribution

rlouf commented 3 years ago

Thank you for taking this up! I will also need exhaustive tests before merging this PR (take the Beta distribution as an example). What we want to test for:

tblazina commented 3 years ago

Of course! I plan to write tests. I just pushed this first commit to show that I'm working on it, and I'll remove the [WIP] once I think it's actually ready for review. I'd be happy to get feedback on numerical correctness! 😃

tblazina commented 3 years ago

I see that the jax.scipy.stats.betabinom code was just added last week and I was wondering if it makes sense to hold off on this PR until jax 0.2.10 is released rather than duplicating/implementing it here?
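For context on what jax.scipy.stats.betabinom computes: the Beta-Binomial log-pmf is log C(n, k) + log B(k + a, n - k + b) - log B(a, b). Below is a minimal pure-Python sketch of that formula, assuming scalar inputs rather than arrays. The name `betabinom_logpmf` is hypothetical and not part of mcx or JAX; a JAX version would presumably use `jax.scipy.special.gammaln` in place of `math.lgamma`.

```python
import math


def betabinom_logpmf(k: int, n: int, a: float, b: float) -> float:
    """Hypothetical helper: log P(K = k) for K ~ BetaBinomial(n, a, b).

    log P(k) = log C(n, k) + log B(k + a, n - k + b) - log B(a, b),
    with every term computed through log-gamma to avoid overflow.
    """
    if not 0 <= k <= n:
        return -math.inf  # k lies outside the support {0, ..., n}

    def log_beta(x: float, y: float) -> float:
        # log B(x, y) = log Gamma(x) + log Gamma(y) - log Gamma(x + y)
        return math.lgamma(x) + math.lgamma(y) - math.lgamma(x + y)

    # log of the binomial coefficient C(n, k)
    log_choose = math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
    return log_choose + log_beta(k + a, n - k + b) - log_beta(a, b)
```

With a = b = 1 the distribution is uniform on {0, ..., n}, so `math.exp(betabinom_logpmf(k, 10, 1.0, 1.0))` equals 1/11 for every k, which makes a convenient sanity check.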

rlouf commented 3 years ago

Yes. JAX has a relatively short release cycle, so we should be able to merge quickly, but we would have to bump the JAX version manually and check that nothing breaks first.

tblazina commented 3 years ago

Ok, sounds good. I'll keep writing the tests, and once JAX 0.2.10 is released and we see it hasn't broken anything, this can be merged. Should the JAX update happen in this PR or separately? Do you have a preference?

rlouf commented 3 years ago

Sounds good! Separate would be better.

tblazina commented 3 years ago

Ok great. I'll finish this up, let you know when it's ready for review, and note that it should only be merged after the PR updating JAX is merged. 👍

rlouf commented 3 years ago

Is it ready for review?

tblazina commented 3 years ago

Sure! Everything is ready except the commented-out tests that check logpdf correctness, which can only run once JAX 0.2.10 is released. I mostly followed the tests in beta_tests.py. I'm not sure whether it would also make sense to add some edge-case tests like those in binomial_tests.py, since the current implementation reuses the _random_binomial function defined in the Binomial distribution code.
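The compound sampling approach mentioned here (draw a success probability from a Beta, then feed it to a binomial sampler) can be sketched in plain Python. This is only an illustration: `sample_betabinom` is a hypothetical name, and it does not reproduce mcx's `_random_binomial`, whose code isn't shown in this thread. A JAX version would instead thread explicit PRNG keys through something like `jax.random.beta`.

```python
import random


def sample_betabinom(rng: random.Random, n: int, a: float, b: float) -> int:
    """Hypothetical sampler: one draw from BetaBinomial(n, a, b) by compounding.

    Step 1: p ~ Beta(a, b).
    Step 2: k ~ Binomial(n, p), here simulated as n Bernoulli(p) trials.
    """
    p = rng.betavariate(a, b)
    # Sum of n Bernoulli(p) trials; simple and exact, but O(n), so only
    # suitable for small n. n == 0 returns 0, a natural edge case to test.
    return sum(rng.random() < p for _ in range(n))


rng = random.Random(0)
draws = [sample_betabinom(rng, 10, 2.0, 3.0) for _ in range(20_000)]
```

Every draw lies in {0, ..., n}, and the empirical mean should be close to n·a/(a + b), i.e. 4 for n = 10, a = 2, b = 3. Checks like these, plus n = 0, are the kind of edge cases a binomial-style test file would cover.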

tblazina commented 3 years ago

Sorry, I had a sick kid at home. I'll make the changes and let you know when it's ready to merge!

rlouf commented 3 years ago

No problem, that's life and we're in no hurry! Is the kid feeling better?

tblazina commented 3 years ago

Ya, she just started daycare, so of course she got sick after one day and got us all sick. 🤷‍♂️ And my partner is applying for professorships and has been very busy writing applications, so ... child care has been falling on me until she sends those in. I'm full of excuses for why I have no time 😉

In any case, I found a bit of time this morning and made the changes you requested; they all look good to me. The tests fail because the PDF correctness tests rely on jax.scipy.stats.betabinom, which hasn't been released yet.

rlouf commented 3 years ago

I just bumped JAX's version to 0.2.10, you can rebase to see if the tests pass.

tblazina commented 3 years ago

Ok, I updated the shape tests, and with JAX 0.2.10 the tests pass 👍 Unless you have any other input, I think it's ready to merge.

rlouf commented 3 years ago

Looks great to me! Thank you for taking the time to add this to the library! 🙂