jax-ml / bayeux

State-of-the-art inference for your Bayesian models.
https://jax-ml.github.io/bayeux/
Apache License 2.0

Implementing nested sampling #55

Open · renecotyfanboy opened this issue 3 months ago

renecotyfanboy commented 3 months ago

Hi there, thank you very much for putting this package together; this is impressive! I was wondering if you would be interested in an implementation of nested sampling in pure JAX. I know that the jaxns package provides an implementation of the Phantom-Powered nested sampling algorithm. I think it would be a nice addition to your collection, and there is already a compatibility layer with numpyro.
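For reference, using that compatibility layer looks roughly like this (a minimal sketch with a toy model I made up here; it assumes jaxns is installed):

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.nested_sampling import NestedSampler

# Toy model: infer the mean of a few observations.
def model(y):
    mu = numpyro.sample("mu", dist.Normal(0.0, 10.0))
    numpyro.sample("y", dist.Normal(mu, 1.0), obs=y)

ns = NestedSampler(model)
ns.run(jax.random.PRNGKey(0), y=jnp.array([0.3, 1.2, 0.8]))
samples = ns.get_samples(jax.random.PRNGKey(1), num_samples=1000)
```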

If you are interested, I can try to draft an implementation of this, even though I would probably wrap the numpyro contributed code instead of working directly with jaxns. WDYT?

ColCarroll commented 3 months ago

Hi! I believe nested sampling relies on some structure in the model, i.e., the ability to factor a joint probability distribution into a prior and a likelihood. I think you could do some really interesting things with bayeux if I (we? you?) figured out some pleasant way of incorporating that into the library. In particular, some of the VI routines from Numpyro, or some of the SMC implementations in blackjax or TFP would become feasible.

I am not sure what that would look like! The input to bayeux is a Callable[PyTree[float], float] (where I'm abusing notation to suggest a PyTree whose leaves are floats). If we call that a LogDensity, I guess a "prior" would need to be some sort of PyTree[LogDensity], but maybe the log densities also need to know how to produce samples...
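Very roughly, and just to make the types concrete (everything below is hypothetical; none of it exists in bayeux):

```python
from typing import Any, Callable, NamedTuple

PyTree = Any  # stand-in: a nested container whose leaves are floats
LogDensity = Callable[[PyTree], float]  # bayeux's current input type

class PriorLeaf(NamedTuple):
    """Hypothetical: a log density that also knows how to sample."""
    log_density: LogDensity
    sample: Callable[[Any], PyTree]  # rng_key -> a draw from this leaf

# A "prior" might then be a PyTree whose leaves are PriorLeaf values,
# while the likelihood stays a plain LogDensity.
Prior = Any  # i.e., PyTree[PriorLeaf]
```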

Anyways, this would need some thought or some design. I'm happy to review a pull request if you have a clear vision, or check out drafts (or you could make your own, more flexible library that supports structured inference!)

renecotyfanboy commented 3 months ago

Hi, apologies for this late answer, I am a bit busy at the moment...

> I believe nested sampling relies on some structure in the model, i.e., the ability to factor a joint probability distribution into a prior and a likelihood.

Even though the jaxns model building requires a prior model and a likelihood, the numpyro wrapper uses an identity function as the prior and the posterior log-probability as the likelihood, and this seems to do the trick! The code is a bit convoluted, though, because jaxns requires explicit signatures for the functions.
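Schematically, the trick is just this (a generic sketch, not the actual numpyro wrapper code, which also has to handle the reparametrization and jaxns's signature requirements):

```python
def wrap_for_nested_sampling(log_density):
    """Sketch of the identity-prior trick described above.

    A nested sampler expects a prior to draw from and a likelihood to
    evaluate. Here the "prior" is the identity on the parameters and
    the full posterior log-density is handed over as the likelihood,
    so no prior/likelihood factorization is ever needed.
    """
    def prior_model(x):
        # Identity "prior": use the sampler's coordinates directly.
        return x

    def log_likelihood(x):
        # The whole posterior log-density plays the likelihood.
        return log_density(x)

    return prior_model, log_likelihood
```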

So for nested sampling, factoring the prior and likelihood out of the log-prob function is not a requirement. I'll draft a PR eventually, just be patient, haha

> In particular, some of the VI routines from Numpyro [...] would become feasible.

Which VI routines are not usable in the current situation? The prior distribution is not readily factorable from a numpyro model either, so I would be curious to see those cases.

ColCarroll commented 2 months ago

No problem! I'm mostly AFK for a week, but wanted to put some thoughts down:

That's interesting that it works well to use an identity function! You certainly understand the situation better than I do, and a draft would be welcome. This library has not had a ton of contributors, so I'm sure some of the automation may be rocky, but I'm happy to spend some time getting it to work if you can provide a starting point (and perhaps a colab of the function working?)

For VI, I'm not a heavy numpyro user, but I guess I was looking at guide generation (https://num.pyro.ai/en/stable/autoguide.html) and thinking it was a little silly to go from, say, a PyMC model to a PyTree to just guessing that every parameter is a Normal. Maybe it would work better than I expect, though! (In particular, I guess this is mean-field VI?)
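Concretely, the sort of thing I was looking at (a minimal sketch with a made-up model, using numpyro's AutoNormal, which fits an independent Normal to each latent in unconstrained space, i.e., mean-field VI):

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal

def model(y):
    mu = numpyro.sample("mu", dist.Normal(0.0, 10.0))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
    numpyro.sample("y", dist.Normal(mu, sigma), obs=y)

# AutoNormal is the "guess every parameter is a Normal" guide.
guide = AutoNormal(model)
svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO())
result = svi.run(jax.random.PRNGKey(0), 2000, y=jnp.array([0.3, 1.2, 0.8]))
```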