rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

Support for other distributions #74

Open Dpananos opened 3 years ago

Dpananos commented 3 years ago

Hi,

It looks like there is support for lots of common distributions. There are a handful of other distributions that are not presently supported but could (fingers crossed) be easily implemented. Looking at Stan's Functions Reference, I see...

Some multivariate distributions:

among other, even more exotic, distributions.

There are also some composite distributions:

Is there interest in implementing some of these distributions? A good portion of them are special cases of more general ones (e.g. the chi-square is a special case of the gamma), so it might be worth implementing them for convenience, even if they just wrap the more general distributions.
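As an illustration of how thin such a wrapper could be: chi2(df) is Gamma(shape = df/2, scale = 2), so a chi-square log-density can delegate entirely to `jax.scipy.stats.gamma.logpdf`. This is a minimal sketch; `chi2_logpdf` is a hypothetical name (not an existing mcx function) and it does no argument validation.

```python
import jax.numpy as jnp
from jax.scipy.stats import gamma


def chi2_logpdf(x, df):
    """Log-density of a chi-square distribution with `df` degrees of freedom.

    Relies on the identity chi2(df) == Gamma(shape=df / 2, scale=2), so it
    only forwards to the existing gamma log-density.
    """
    return gamma.logpdf(x, df / 2.0, scale=2.0)


xs = jnp.array([0.5, 1.0, 3.0])
print(chi2_logpdf(xs, 4.0))

# With df = 2 the chi-square is an exponential with rate 1/2, so this
# should be close to -log(2) - 0.5 (about -1.193).
print(chi2_logpdf(1.0, 2.0))
```

Delegating to the general distribution keeps a single code path to test; the cost is only an extra division and a fixed scale argument per call.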

If there is interest, I'd be up for implementing a portion of them (if only to familiarize myself with the PDFs, but also to work with JAX some more).

rlouf commented 3 years ago

There definitely is an interest in implementing as many distributions as is possible!

I would first rank them roughly by how often you think they're used, if they aren't ordered that way already (who am I to know?). Then you can pick whichever you like and open a PR to signal you're working on it.

Also, do you mind if I edit your post to add distributions I think are important and pin it in the issue tracker so others can see what's up for grabs?

Dpananos commented 3 years ago

Yea, go ahead and edit if you like.

I recently opened a PR on jax to add the chi-square. If that goes well, I will make a PR here for the chi-square and then work through the list.
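Until a dedicated chi-square sampler is available, draws can be obtained from `jax.random.gamma`, which samples a unit-scale Gamma; doubling those draws gives Gamma(df/2, scale = 2), i.e. chi2(df). This is a minimal sketch; `chi2_sample` is a hypothetical helper and not the API proposed in the JAX PR mentioned above.

```python
import jax


def chi2_sample(key, df, shape=()):
    """Draw chi-square samples by rescaling standard gamma draws.

    jax.random.gamma samples Gamma(shape=a, scale=1); multiplying by 2
    gives Gamma(a, scale=2), which is chi2(2 * a).
    """
    return 2.0 * jax.random.gamma(key, df / 2.0, shape=shape)


key = jax.random.PRNGKey(0)
draws = chi2_sample(key, 3.0, shape=(100_000,))

# chi2(df) has mean df and variance 2 * df, so these should land
# near 3.0 and 6.0 respectively.
print(draws.mean(), draws.var())
```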

tblazina commented 3 years ago

I'll try my hand at the Beta-Binomial; I'll open a PR for it.
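For reference, the Beta-Binomial pmf is P(K = k) = C(n, k) B(k + alpha, n - k + beta) / B(alpha, beta), which can be evaluated stably in log space with `jax.scipy.special.gammaln`. A minimal sketch assuming the usual (n, alpha, beta) parameterization; `beta_binomial_logpmf` and `_log_beta` are hypothetical names, and support checks (integer 0 <= k <= n) are left out.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def _log_beta(a, b):
    # log B(a, b) expressed through log-gamma functions.
    return gammaln(a) + gammaln(b) - gammaln(a + b)


def beta_binomial_logpmf(k, n, alpha, beta):
    """log P(K = k) for K ~ BetaBinomial(n, alpha, beta).

    No support checks: k is assumed to be an integer in [0, n].
    """
    log_choose = gammaln(n + 1.0) - gammaln(k + 1.0) - gammaln(n - k + 1.0)
    return log_choose + _log_beta(k + alpha, n - k + beta) - _log_beta(alpha, beta)


ks = jnp.arange(11)
probs = jnp.exp(beta_binomial_logpmf(ks, 10, 2.0, 3.0))
print(probs.sum())  # should be very close to 1
```

With alpha = beta = 1 the Beta mixing distribution is uniform and the pmf reduces to 1/(n + 1) for every k, which makes a convenient sanity check.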

tblazina commented 3 years ago

I'll also try my hand at the Pareto and open a PR for it.
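A Pareto (type I) log-density is a one-liner plus a support mask. This sketch assumes the shape/scale parameterization used by `scipy.stats.pareto`, where pdf(x) = shape * scale^shape / x^(shape + 1) for x >= scale; `pareto_logpdf` is a hypothetical name.

```python
import jax.numpy as jnp


def pareto_logpdf(x, shape, scale=1.0):
    """Log-density of a type I Pareto distribution with support x >= scale."""
    x = jnp.asarray(x)
    logp = jnp.log(shape) + shape * jnp.log(scale) - (shape + 1.0) * jnp.log(x)
    # Outside the support the density is zero, i.e. the log-density is -inf.
    return jnp.where(x >= scale, logp, -jnp.inf)


print(pareto_logpdf(jnp.array([0.5, 1.0, 2.0]), 3.0))
```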

tblazina commented 3 years ago

I'll go ahead and try the Weibull next. I'll open another WIP PR.

Update:

I just noticed that the weibull_min sampler has been implemented in jax.random, but the pdf/logpdf of the distribution hasn't been added to jax.scipy.stats yet. I could either a) implement it here in mcx, or b) open a PR in jax to add it there before adding the distribution here. I suppose b is probably the better option.
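Whichever option is taken, the log-density itself is short: log f(x) = log k - log lam + (k - 1) log(x / lam) - (x / lam)^k for x >= 0, with concentration k and scale lam. A minimal sketch, assuming the scale/concentration parameterization of jax.random.weibull_min; `weibull_min_logpdf` is a hypothetical name.

```python
import jax.numpy as jnp
from jax.scipy.special import xlogy


def weibull_min_logpdf(x, scale, concentration):
    """Log-density of a Weibull distribution with support x >= 0."""
    x = jnp.asarray(x)
    z = x / scale
    # xlogy(0, z) is defined as 0, so concentration == 1 stays finite at x == 0.
    logp = (
        jnp.log(concentration)
        - jnp.log(scale)
        + xlogy(concentration - 1.0, z)
        - z**concentration
    )
    # Negative values lie outside the support.
    return jnp.where(x >= 0, logp, -jnp.inf)


print(weibull_min_logpdf(jnp.array([0.5, 1.0, 2.0]), 1.0, 1.5))
```

With concentration = 1 this reduces to an exponential with rate 1/scale, which is an easy case to check against.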

rlouf commented 3 years ago

I would open a PR in JAX.