pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Add Initial SDVI Implementation #1758

Closed: treigerm closed this pull request 4 months ago

treigerm commented 4 months ago

This is follow-up work from https://github.com/pyro-ppl/numpyro/pull/1715. It adds Support Decomposition VI (SDVI), a variational method for programs with stochastic support, as discussed in #1697.

Once this is merged, I will add a simple tutorial on how to use these algorithms properly. At the moment, only the most basic versions of DCC and SDVI are implemented, so the idea is that we can gradually add more bells and whistles over time (most prominently, running inference on different program paths in parallel). For now, I wanted to keep the implementation simple so that the PRs stay at a reasonable size.
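For readers new to the idea, here is a minimal sketch in plain Python (not the NumPyro API, and not the code in this PR) of what "stochastic support" means and how per-path results might be combined. It assumes the commonly described SDVI recipe: split the program into paths with a fixed set of latent variables, fit a guide per path, and weight each path in proportion to the exponential of its ELBO. The names `toy_program`, `path_of`, `sdvi_weights`, and the ELBO values are hypothetical and chosen only for illustration.

```python
import math
import random


def toy_program(rng):
    """Toy generative program whose set of latent variables depends on a branch."""
    trace = {"z": rng.random() < 0.5}  # discrete branching choice
    if trace["z"]:
        trace["x"] = rng.gauss(0.0, 1.0)  # this path has one continuous latent
    else:
        trace["y1"] = rng.gauss(0.0, 1.0)  # this path has two continuous latents
        trace["y2"] = rng.gauss(0.0, 1.0)
    return trace


def path_of(trace):
    """Identify a program path by the sorted names of the variables it sampled."""
    return tuple(sorted(trace))


def sdvi_weights(elbos):
    """Turn per-path ELBO estimates into mixture weights with a stable softmax."""
    m = max(elbos.values())
    unnorm = {k: math.exp(v - m) for k, v in elbos.items()}
    total = sum(unnorm.values())
    return {k: u / total for k, u in unnorm.items()}


# Discover the distinct paths by running the program forward a few times.
rng = random.Random(0)
paths = {path_of(toy_program(rng)) for _ in range(200)}

# Hypothetical ELBO estimates, standing in for the result of fitting one guide per path.
elbos = {"branch_a": -3.2, "branch_b": -1.1, "branch_c": -7.5}
weights = sdvi_weights(elbos)
```

Each element of `paths` corresponds to a sub-program with fixed support, which is what a per-path guide would target. Because the paths are independent of one another until the final weighting step, fitting them in parallel is a natural extension.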

fehiepsi commented 4 months ago

Thanks, @treigerm! The PR looks to be in great shape. Do you want me to take an extra look at any details?

treigerm commented 4 months ago

Thanks, @fehiepsi! I don't think there is anything in particular that needs attention. So if nothing sticks out to you as problematic, I am happy for you to merge!

fehiepsi commented 4 months ago

Yeah, your plan sounds good to me.