renecotyfanboy opened 3 months ago
Hi! I believe nested sampling relies on some structure in the model, i.e., the ability to factor a joint probability distribution into a prior and a likelihood. I think you could do some really interesting things with bayeux if I (we? you?) figured out some pleasant way of incorporating that into the library. In particular, some of the VI routines from Numpyro, or some of the SMC implementations in blackjax or TFP, would become feasible.
I am not sure what that would look like! The input to bayeux is a Callable[PyTree[float], float] (where I'm abusing notation to suggest a PyTree whose leaves are floats). If we call that a LogDensity, I guess a "prior" would need to be some sort of PyTree[LogDensity], but maybe the log densities also need to know how to produce samples...
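To make the PyTree[LogDensity] idea a bit more concrete, here is a minimal sketch in plain Python. Everything in it is hypothetical and not part of bayeux: PriorLeaf, normal, prior_log_prob, prior_sample and make_log_likelihood are illustrative names, a flat dict stands in for a PyTree, and the leaves are assumed independent so that the joint prior log density is a sum over leaves. Each leaf carries both a log density and a sampler, and the likelihood is recovered by subtracting the prior term from the joint log density that bayeux already receives.

```python
import math
import random
from dataclasses import dataclass
from typing import Callable, Dict

Params = Dict[str, float]  # hypothetical stand-in for PyTree[float]


@dataclass(frozen=True)
class PriorLeaf:
    """Hypothetical leaf: a log density that also knows how to sample."""
    log_prob: Callable[[float], float]
    sample: Callable[[random.Random], float]


def normal(loc: float, scale: float) -> PriorLeaf:
    """Build a Normal(loc, scale) leaf."""
    def log_prob(x: float) -> float:
        z = (x - loc) / scale
        return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)

    return PriorLeaf(log_prob=log_prob, sample=lambda rng: rng.gauss(loc, scale))


def prior_log_prob(prior: Dict[str, PriorLeaf], params: Params) -> float:
    """Joint prior log density, assuming the leaves are independent."""
    return sum(leaf.log_prob(params[name]) for name, leaf in prior.items())


def prior_sample(prior: Dict[str, PriorLeaf], rng: random.Random) -> Params:
    """Draw one parameter dict with the same structure as the prior."""
    return {name: leaf.sample(rng) for name, leaf in prior.items()}


def make_log_likelihood(
    log_density: Callable[[Params], float], prior: Dict[str, PriorLeaf]
) -> Callable[[Params], float]:
    """Recover log-likelihood = joint log density - prior log density."""
    return lambda params: log_density(params) - prior_log_prob(prior, params)
```

A structured sampler (SMC, nested sampling) could then draw initial particles with prior_sample and weight them with the returned likelihood; whether this factorization is the right design for bayeux is exactly the open question here.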
Anyways, this would need some thought or some design. I'm happy to review a pull request if you have a clear vision, or check out drafts (or you could make your own, more flexible library that supports structured inference!)
Hi, apologies for the late answer; I am a bit busy at the moment...
> I believe nested sampling relies on some structure in the model, i.e., the ability to factor a joint probability distribution into a prior and a likelihood.
Even though jaxns model building requires a prior model and a likelihood, the numpyro wrapper uses an identity function as the prior and the posterior log-probability as the likelihood, and this seems to do the trick! The code is a bit convoluted, though, because jaxns requires explicit signatures for its functions.
So for nested sampling, factoring the log-prob function into a prior and a likelihood is not a requirement. I'll draft a PR eventually, just be patient, ahah
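The identity-prior trick can be illustrated with a toy nested sampler. This is a minimal textbook sketch in plain Python, not the jaxns algorithm and not numpyro's wrapper. Assumptions: the "prior" is Uniform(0, 1) with an identity transform, the whole unnormalized log density (a hypothetical 1-D target, log(u(1 - u))) is passed in as the likelihood, live points are replaced by simple rejection sampling, and the prior-volume shrinkage uses its expected value exp(-i / n_live).

```python
import math
import random


def log_target(u: float) -> float:
    """Hypothetical unnormalized log density on (0, 1): log(u * (1 - u)).

    Plays the role of the full posterior log-probability. Its integral over
    the unit interval is 1/6, so the exact log-evidence is log(1/6).
    """
    if not 0.0 < u < 1.0:
        return -math.inf
    return math.log(u * (1.0 - u))


def logaddexp(a: float, b: float) -> float:
    """Numerically stable log(exp(a) + exp(b))."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    hi = max(a, b)
    return hi + math.log(math.exp(a - hi) + math.exp(b - hi))


def nested_sampling(log_like, n_live: int = 100, n_iter: int = 600, seed: int = 0) -> float:
    """Return a log-evidence estimate using an identity prior transform.

    The prior is Uniform(0, 1) and a unit-cube point is passed to log_like
    unchanged, so all of the model lives in log_like. Replacement points come
    from rejection sampling under the prior, which only works in very low
    dimensions.
    """
    rng = random.Random(seed)
    live = [rng.random() for _ in range(n_live)]
    live_ll = [log_like(u) for u in live]
    log_z = -math.inf
    log_x_prev = 0.0  # log prior volume enclosed by the current contour
    for i in range(1, n_iter + 1):
        worst = min(range(n_live), key=lambda k: live_ll[k])
        ll_min = live_ll[worst]
        log_x = -i / n_live  # expected shrinkage X_i = exp(-i / n_live)
        log_w = math.log(math.exp(log_x_prev) - math.exp(log_x))
        log_z = logaddexp(log_z, ll_min + log_w)
        log_x_prev = log_x
        # Replace the worst live point by a prior draw with L > L_min.
        while True:
            u = rng.random()
            ll = log_like(u)
            if ll > ll_min:
                break
        live[worst], live_ll[worst] = u, ll
    # Remaining live points share the leftover volume X_final equally.
    for ll in live_ll:
        log_z = logaddexp(log_z, ll + log_x_prev - math.log(n_live))
    return log_z
```

With these settings, nested_sampling(log_target) should come out close to log(1/6) ≈ -1.79: the evidence computed this way is just the integral of the unnormalized density over the unit cube. For parameters that do not live on the unit cube, a transform onto it (plus the corresponding Jacobian term in the likelihood) would be needed, and in higher dimensions rejection sampling would have to give way to a constrained-move scheme.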
> In particular, some of the VI routines from Numpyro [...] would become feasible.
Which VI routines are not usable in the current situation? The prior distribution is not readily factorable from a numpyro model, so I would be curious to see those cases.
No problem! I'm mostly AFK for a week, but wanted to put some thoughts down:
It's interesting that using an identity function works well! You certainly understand the situation better than I do, and a draft would be welcome. This library has not had a ton of contributors, so I'm sure some of the automation will be rocky, but I'm happy to spend some time getting it to work if you can provide a starting point (and perhaps a colab of the function working?)
For VI, I'm not a heavy numpyro user, but I guess I was looking at guide generation (https://num.pyro.ai/en/stable/autoguide.html) and thinking it was a little silly to go from, say, a PyMC model to a PyTree to just guessing that every parameter is a Normal. Maybe it would work better than I expect, though! (in particular, I guess this is mean field VI?)
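On the mean-field question: a diagonal (mean-field) Normal approximation can be fit using nothing but the gradient of the log density, which is the kind of input bayeux already has. Below is a minimal sketch in plain Python, not numpyro's autoguide implementation. It uses the reparameterization x = m + s * eps with plain stochastic gradient ascent on the ELBO, and a hypothetical independent Gaussian target whose gradient is written by hand (in JAX it would come from jax.grad). The names mean_field_vi, grad_log_target, MU, SIGMA and all defaults are illustrative.

```python
import math
import random
from typing import Callable, List, Tuple


def mean_field_vi(
    grad_log_density: Callable[[List[float]], List[float]],
    dim: int,
    steps: int = 2000,
    n_samples: int = 10,
    lr: float = 0.05,
    seed: int = 0,
) -> Tuple[List[float], List[float]]:
    """Fit q(x) = prod_j Normal(m_j, s_j) to an unnormalized log density.

    Maximizes the ELBO  E_q[log p(x)] + sum_j log s_j + const  by stochastic
    gradient ascent, using x = m + s * eps with eps ~ Normal(0, 1).
    Only the gradient of log p is needed; no prior/likelihood split.
    """
    rng = random.Random(seed)
    m = [0.0] * dim
    log_s = [0.0] * dim  # optimize log-scales so s stays positive
    for _ in range(steps):
        s = [math.exp(v) for v in log_s]
        g_m = [0.0] * dim
        g_log_s = [1.0] * dim  # d/d(log s_j) of the entropy term is 1
        for _ in range(n_samples):
            eps = [rng.gauss(0.0, 1.0) for _ in range(dim)]
            x = [m[j] + s[j] * eps[j] for j in range(dim)]
            g = grad_log_density(x)
            for j in range(dim):
                # Chain rule through x_j = m_j + exp(log s_j) * eps_j.
                g_m[j] += g[j] / n_samples
                g_log_s[j] += g[j] * eps[j] * s[j] / n_samples
        for j in range(dim):
            m[j] += lr * g_m[j]
            log_s[j] += lr * g_log_s[j]
    return m, [math.exp(v) for v in log_s]


# Hypothetical target: independent Normals with these means and scales.
MU = [1.0, -2.0]
SIGMA = [0.5, 2.0]


def grad_log_target(x: List[float]) -> List[float]:
    """Gradient of sum_j log Normal(x_j; MU_j, SIGMA_j)."""
    return [-(xj - mu) / sd**2 for xj, mu, sd in zip(x, MU, SIGMA)]
```

Because this target is itself a product of independent Normals, mean_field_vi(grad_log_target, dim=2) should return means and scales near [1.0, -2.0] and [0.5, 2.0]. For correlated posteriors, a mean-field fit of this kind typically underestimates the marginal variances, which is one reason "every parameter is a Normal" may or may not be good enough.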
Hi there, thank you very much for putting this package together; this is impressive! I was wondering if you would be interested in an implementation of nested sampling in pure JAX. I know that the jaxns package provides an implementation of the Phantom-Powered Nested Sampling algorithm. I think it would be a nice addition to your collection, and there is already a compatibility layer with numpyro. If you are interested, I can try to draft an implementation, although I would probably wrap the numpyro contributed code instead of working directly with jaxns. WDYT?