marius311 / muse_inference


Contributing MUSE to BlackJAX #1

Open junpenglao opened 2 years ago

junpenglao commented 2 years ago

Hi @marius311, great work on MUSE! I am wondering whether you would be interested in contributing MUSE to BlackJAX.

Some benefits:

Best, Junpeng

marius311 commented 2 years ago

Thanks for the interest! I agree those reasons are all highly compelling and I'm definitely interested, but given my limited time right now I think I have to prioritize finishing this package (which will support just JAX and PyMC at first) and some other projects. But potentially, once the JAX part of this is done, adding BlackJAX support is pretty trivial? I haven't looked too deeply into that package, but would part of this mean writing the MUSE solver itself in a way that is jit-able by JAX? (The solver here uses the jit-ed posterior and gradients, but may not itself be jit-able; I haven't checked.)
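On the question of whether the solver itself can be jit-ed: the usual obstacle is a Python-level convergence loop, because a Python `while` that tests a traced value fails under `jax.jit`. Below is a minimal sketch, assuming a hypothetical toy model (latents z_i ~ N(θ, 1), data x_i ~ N(z_i, 1)), of MUSE's inner MAP step ẑ = argmax_z log P(x, z | θ) written with `jax.lax.while_loop` so that even a tolerance-based stopping rule compiles. The model, step size, tolerance, and function names are illustrative assumptions, not taken from muse_inference.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy joint density (not from muse_inference):
# latents z_i ~ N(theta, 1), data x_i ~ N(z_i, 1).
def logp_xz(x, z, theta):
    return -0.5 * jnp.sum((z - theta) ** 2) - 0.5 * jnp.sum((x - z) ** 2)

STEP_SIZE = 0.25   # assumed small enough for this well-conditioned toy model
TOL = 1e-5         # stop once every |d logp / dz| entry is below this
MAX_STEPS = 1000   # hard cap so the loop always terminates

@jax.jit
def map_z(x, theta):
    """argmax_z log P(x, z | theta) by gradient ascent, compiled end to end.

    A Python `while` testing convergence on a traced value would fail under
    jax.jit; jax.lax.while_loop moves the stopping rule into compiled code."""
    grad_z = jax.grad(logp_xz, argnums=1)

    def keep_going(state):
        _, g, i = state
        return (jnp.max(jnp.abs(g)) > TOL) & (i < MAX_STEPS)

    def ascent_step(state):
        z, g, i = state
        z = z + STEP_SIZE * g
        return z, grad_z(x, z, theta), i + 1

    z0 = jnp.zeros_like(x)
    z_hat, _, _ = jax.lax.while_loop(
        keep_going, ascent_step, (z0, grad_z(x, z0, theta), jnp.array(0)))
    return z_hat

x = jnp.array([1.0, 3.0])
print(map_z(x, 1.0))  # analytic MAP for this model is (theta + x) / 2 = [1, 2]
```

For this toy model the MAP has the closed form (θ + x) / 2, which makes the compiled loop easy to sanity-check; a real model would need a more careful optimizer than fixed-step ascent.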

junpenglao commented 2 years ago

I think the code change should be pretty minimal; what makes an algorithm "integrated" into BlackJAX is its high-level, user-facing API. For example, for MCMC the user supplies a log_prob function that takes the parameters (we are also working on SGMCMC, which takes the parameters and a (minibatch of) observed data). For MUSE, most of the work will be around making it easy for users to supply the required components (theta_logp, sample_x_z, etc.).
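A minimal sketch of what such a user-facing API could look like, assuming the user supplies three functions: a joint log density `logp_xz(x, z, theta)`, a prior `theta_logp(theta)`, and a simulator `sample_x_z(key, theta)`. Only `theta_logp` and `sample_x_z` come from the comment above; `make_muse_score`, `logp_xz`, and the toy model are hypothetical and are not BlackJAX or muse_inference APIs. The builder returns a jit-compiled estimate of the MUSE score as we read it: s(x_obs) minus the average of s(x) over simulations x ~ P(x | θ), plus the prior gradient, where s(x) = ∂/∂θ log P(x, ẑ | θ) at the MAP latent ẑ.

```python
import jax
import jax.numpy as jnp


def make_muse_score(logp_xz, theta_logp, sample_x_z,
                    num_sims=100, num_map_steps=100, map_step_size=0.25):
    """Hypothetical builder (not a BlackJAX or muse_inference API): turns
    user-supplied model pieces into a jit-compiled MUSE score estimate."""
    grad_z = jax.grad(logp_xz, argnums=1)
    grad_theta = jax.grad(logp_xz, argnums=2)

    def z_map(x, theta):
        # Fixed-count gradient ascent on z; map_step_size is an assumption
        # that suits well-conditioned models like the toy one below.
        def body(_, z):
            return z + map_step_size * grad_z(x, z, theta)
        return jax.lax.fori_loop(0, num_map_steps, body, jnp.zeros_like(x))

    def score(x, theta):
        # s(x) = d/dtheta log P(x, z_hat | theta), at the MAP latent z_hat.
        return grad_theta(x, z_map(x, theta), theta)

    @jax.jit
    def muse_score(key, x_obs, theta):
        keys = jax.random.split(key, num_sims)
        # Simulate (x, z) at the current theta; only x enters the score.
        x_sims, _ = jax.vmap(lambda k: sample_x_z(k, theta))(keys)
        s_sims = jax.vmap(lambda x: score(x, theta))(x_sims)
        # Observed score minus its simulation average, plus the prior gradient.
        return (score(x_obs, theta) - jnp.mean(s_sims, axis=0)
                + jax.grad(theta_logp)(theta))

    return muse_score


# Hypothetical toy model: z_i ~ N(theta, 1), x_i ~ N(z_i, 1), i = 1..N.
N = 50

def logp_xz(x, z, theta):
    return -0.5 * jnp.sum((z - theta) ** 2) - 0.5 * jnp.sum((x - z) ** 2)

def theta_logp(theta):
    return -0.5 * theta ** 2 / 10.0 ** 2  # broad Gaussian prior on theta

def sample_x_z(key, theta):
    key_z, key_x = jax.random.split(key)
    z = theta + jax.random.normal(key_z, (N,))
    x = z + jax.random.normal(key_x, (N,))
    return x, z


muse_score = make_muse_score(logp_xz, theta_logp, sample_x_z)
key_obs, key_sims = jax.random.split(jax.random.PRNGKey(0))
x_obs, _ = sample_x_z(key_obs, 2.0)  # "observed" data generated at theta = 2
print(muse_score(key_sims, x_obs, 0.0), muse_score(key_sims, x_obs, 4.0))
```

The design choice here is that the user only writes model code; simulation, the inner MAP solves, and the averaging are all traced, so `jax.jit` compiles them into one function, and a root finder over θ could be layered on top. For this toy model s(x) is linear in x, so the score evaluated at θ = 0 and θ = 4 should point back toward the generating value, and its root should sit close to the mean of `x_obs`.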