pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

HMC Gibbs Example #1088

Open ross-h1 opened 3 years ago

ross-h1 commented 3 years ago

Hi all, I've been working with the HMC Gibbs sampler.

Thanks for your work on this; I think it's a really unique and powerful feature. I've written an example notebook (attached) that applies HMCGibbs and plain HMC to the same hierarchical problem, and it shows a significant speed-up with HMCGibbs as the dimension of the conditionally conjugate part of the problem increases (vary the number of series n to see this). In my experiments HMCGibbs scales more or less linearly with n; plain HMC is better on small problems, but orders of magnitude worse as the dimension increases.

Numpyro HMC Gibbs.ipynb.zip

A few questions and queries...

  1. Some more examples and guidance / documentation on this would be super useful.
  2. What's the idiomatic way of passing values computed within the model function to the Gibbs function? I've done this here using 'deterministic' statements, and it seems to work, though it feels a bit of a hack (undocumented functionality!).
  3. If you spot any further refinements (or errors) in the attached, please let me know!
ross-h1 commented 3 years ago

A few corrections to this:

In gibbs_fn, the scale needs a square root: betai = dist.Normal(beta_post_mean, jnp.sqrt(beta_post_var)).sample(rng_key)

And the DataFrame of results should be: pd.DataFrame(np.asarray([[10,11],[58,30],[1920,55]]),index=[50,500,1000],columns=['hmc','hmc/gibbs'])
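The thread doesn't state the units of these timings. Taking the corrected table at face value, the ratio of the two columns gives the speed-up of HMCGibbs over plain HMC for each n (the "speedup" column name is just for illustration):

```python
import numpy as np
import pandas as pd

times = pd.DataFrame(
    np.asarray([[10, 11], [58, 30], [1920, 55]]),
    index=[50, 500, 1000],
    columns=["hmc", "hmc/gibbs"],
)
times.index.name = "n"
# Ratio > 1 means HMCGibbs was faster than plain HMC for that n.
times["speedup"] = times["hmc"] / times["hmc/gibbs"]
print(times.round(2))
```

On these numbers plain HMC is slightly faster at n=50 (ratio about 0.91), while HMCGibbs is about 1.9x faster at n=500 and about 35x faster at n=1000.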

fehiepsi commented 3 years ago

Whats the idiomatic way of transferring values computed within the model function to the Gibbs function

I think using deterministic is intended here. Could you mention this usage in the docs for gibbs_fn?

You can find some examples in the tests, but I think there's nothing there that's new to you. Given your experience with this class, it would be nice if you could contribute an example or tutorial for users. :)

(Just a note that you can set group_by_chain=True in mcmc.get_samples() to add an additional chain dimension to the samples; with a single chain this is a singleton dimension.)

ross-h1 commented 3 years ago

Thanks for the hint about group_by_chain. I'm very happy to contribute an example/tutorial based on this notebook. Let me know what else I need to do!

edwinnglabs commented 1 year ago

Is this tutorial, or a similar tutorial for HMCGibbs, officially part of the NumPyro tutorials yet?

fehiepsi commented 1 year ago

Hi @edwinnglabs, we don't have a tutorial yet. But we have an example in the doc https://num.pyro.ai/en/stable/mcmc.html#numpyro.infer.hmc_gibbs.HMCGibbs. Hope that it's helpful.

martinjankowiak commented 1 year ago

@edwinnglabs you can also find examples in the tests