CDCgov / PyRenew

Python package for multi-signal Bayesian renewal modeling with JAX and NumPyro.
https://cdcgov.github.io/PyRenew/
Apache License 2.0

Should pyrenew's default MCMC sampler backend be numpyro's bundled NUTS sampler? #466

Open dylanhmorris opened 2 weeks ago

dylanhmorris commented 2 weeks ago

Or should the standard pattern be to specify the model in numpyro but pass it off to bayeux, blackjax, etc. for inference?

@CDCgov/pyrenew-devs

See also #361 #151

damonbayer commented 2 weeks ago

I've been scared away from bayeux a bit. It seems it is not exactly plug-and-play with numpyro: https://github.com/jax-ml/bayeux/issues/51

AFg6K7h4fhy2 commented 2 weeks ago

My preference is numpyro over blackjax. I'm not sure where I stand on bayeux, given that I haven't done much with it. Having other JAX-based sampling options available seems good, but for the default I currently think deferring to numpyro is best (this could of course change if I learned bayeux was superior in some way).

seabbs commented 2 weeks ago

I did a bunch of digging into within-chain parallelization. My current takeaway is that jax can do it, and maybe some of the other backends can, but numpyro can't. If that is actually the case, it puts more weight on being backend-agnostic, at least at some point in the future, IMO.

I really can't see how or why that would be the case, though, so I'm very keen to see evidence to the contrary from people who have actually fit a numpyro model and so might be better informed.