pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Support scan for Trace_ELBO #1693

Closed: deoxyribose closed this 9 months ago

deoxyribose commented 10 months ago

Issue #1685

This change reuses the substitute_stack to store the replay trace and replays the sites in the scanned function one iteration at a time.

fehiepsi commented 10 months ago

Nice support, @deoxyribose! Could you add tests for this change?

deoxyribose commented 10 months ago

I've added a test that runs SVI with an AutoNormal guide and checks that the results are reasonably accurate. I've kept the model and data size to a minimum; the test takes around 3-4 seconds on my machine. There was already a test combining scan, SVI and AutoNormal (test_subsample_guide in test/infer/test_autoguide.py), but it does not run inference or check the results, so the existing version of scan passes it even without this change.