pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

pmap over `num_particles` in SVI #1645

Open snehjp2 opened 11 months ago

snehjp2 commented 11 months ago

Hi,

In `Trace_ELBO`, setting `num_particles > 1` effectively introduces a batch dimension when estimating the ELBO gradient. By default, the computation is vectorized over the particles. Is it possible to also distribute this batch dimension over devices (e.g. when running on multiple GPUs)? My particular application is prone to JAX OOM errors and would benefit from distributing the particles with `jax.pmap`.

fehiepsi commented 11 months ago

If you hit OOM, you can set `vectorize_particles=False`. You can also use `PositionalSharding`, like in MCMC, I guess.
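To illustrate the memory trade-off behind that suggestion, here is a minimal pure-JAX sketch rather than NumPyro's actual implementation. It assumes the non-vectorized path evaluates particles one at a time (modeled here with `jax.lax.map`); `single_particle_loss` and `particle_average` are hypothetical names, and the toy loss only stands in for a one-particle ELBO estimate.

```python
import jax
import jax.numpy as jnp
from jax import random


def single_particle_loss(key):
    # Hypothetical stand-in for a one-particle ELBO estimate:
    # a Monte Carlo estimate of E[z^2] with z ~ N(0, 1).
    z = random.normal(key, (1000,))
    return jnp.mean(z ** 2)


def particle_average(key, num_particles, vectorize_particles=True):
    # One PRNG key per particle.
    keys = random.split(key, num_particles)
    if vectorize_particles:
        # All particles are evaluated in one batched computation,
        # so intermediate buffers grow with num_particles.
        losses = jax.vmap(single_particle_loss)(keys)
    else:
        # Particles are evaluated sequentially, so only one
        # particle's intermediates are live at a time.
        losses = jax.lax.map(single_particle_loss, keys)
    return jnp.mean(losses)


key = random.PRNGKey(0)
batched = particle_average(key, 8, vectorize_particles=True)
sequential = particle_average(key, 8, vectorize_particles=False)
```

Both calls use the same keys, so they should agree up to floating-point error; the sequential version trades speed for lower peak memory.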

If you want to pmap over particles, could you make a PR for it? I think we can simply allow a callable `vectorize_particles` and call it here.
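A rough sketch of what a callable `vectorize_particles` could look like, written in plain JAX and not taken from the NumPyro source. The helper names (`single_particle_loss`, `particle_average`) and the dispatch logic are assumptions for illustration only; the one thing taken from the comment above is the idea that a callable is applied to the per-particle function.

```python
import jax
import jax.numpy as jnp
from jax import random


def single_particle_loss(key):
    # Hypothetical stand-in for a one-particle ELBO estimate.
    z = random.normal(key, (1000,))
    return jnp.mean(z ** 2)


def particle_average(key, num_particles, vectorize_particles=True):
    keys = random.split(key, num_particles)
    if callable(vectorize_particles):
        # Assumed behavior: the callable transforms the per-particle
        # function into one that maps over the leading axis of `keys`.
        losses = vectorize_particles(single_particle_loss)(keys)
    elif vectorize_particles:
        losses = jax.vmap(single_particle_loss)(keys)
    else:
        losses = jax.lax.map(single_particle_loss, keys)
    return jnp.mean(losses)


# jax.pmap needs the mapped axis to be no larger than the number of
# local devices, so this sketch uses one particle per device.
n_dev = jax.local_device_count()
key = random.PRNGKey(0)
pmapped = particle_average(key, n_dev, vectorize_particles=jax.pmap)
vmapped = particle_average(key, n_dev, vectorize_particles=True)
```

With more particles than devices, one option would be to reshape the keys to `(n_dev, particles_per_device, ...)` and pass a callable that nests `jax.vmap` inside `jax.pmap`.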