pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

[FR] make DiscreteHMCGibbs plate aware #1406

Open martinjankowiak opened 2 years ago

martinjankowiak commented 2 years ago

AFAIK, DiscreteHMCGibbs does not make use of plate information when computing Gibbs updates for discrete latent variables. It would be nice to support this, since exploiting the conditional independence that plates encode can make Gibbs steps exponentially cheaper.
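One way to read the "exponentially cheaper" claim, as a toy sketch in plain NumPy rather than NumPyro code: if the N discrete sites in a plate share no factors, the joint conditional over all K**N configurations factorizes into N independent K-way categoricals. The `logits` table and the sizes `N` and `K` below are made-up illustration values, not anything from the library.

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)
N, K = 4, 3  # plate size and number of discrete values (toy numbers)

# Assumed per-element log factors: logits[i, k] = log p(z_i = k, x_i | rest).
# Inside a plate, these factors do not couple different indices i.
logits = rng.normal(size=(N, K))

# Naive approach: enumerate all K**N joint configurations of z.
configs = np.array(list(itertools.product(range(K), repeat=N)))  # (K**N, N)
joint_logp = logits[np.arange(N), configs].sum(axis=1)           # (K**N,)
joint_p = np.exp(joint_logp - joint_logp.max())
joint_p /= joint_p.sum()

# Marginal of each z_i under the enumerated joint distribution.
naive_marginals = np.zeros((N, K))
for i in range(N):
    np.add.at(naive_marginals[i], configs[:, i], joint_p)

# Plate-aware approach: normalize each row on its own, N * K terms in total.
plate_p = np.exp(logits - logits.max(axis=1, keepdims=True))
plate_p /= plate_p.sum(axis=1, keepdims=True)

n_joint_terms = K ** N  # grows exponentially with the plate size
n_plate_terms = N * K   # grows linearly with the plate size
```

Both routes give the same per-site conditionals, but the enumerated table has K**N entries while the plate-aware one has only N * K.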

fehiepsi commented 2 years ago

This sounds possible. I think we can perform the discrete updates in blocks (maybe #898 is helpful), where for each block we decide which strategy to use (i.e. the current DiscreteHMCGibbs strategy or a plate-aware strategy). For the plate-aware strategy, I guess we just need to maintain the plate information of each factor, remove unnecessary factors (based on provenance?), then perform the Gibbs update in one pass.
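A rough sketch of what a "one pass" plate-aware update could look like, assuming a simple Gaussian mixture where each assignment z_i only appears in the factor for its own observation x_i. This is plain NumPy, not the DiscreteHMCGibbs implementation; `plate_gibbs_update` and its arguments are hypothetical names, and the provenance-based factor pruning mentioned above is taken as already done (only the factors that depend on z are evaluated).

```python
import numpy as np

def plate_gibbs_update(rng, x, locs, log_weights, scale=1.0):
    """Resample every z_i in a plate at once (hypothetical helper).

    x: (N,) observations, locs: (K,) component means,
    log_weights: (K,) log mixture weights.
    """
    # Per-element table of unnormalized log conditionals, shape (N, K):
    # log p(z_i = k) + log Normal(x_i | locs[k], scale), up to a constant.
    log_lik = -0.5 * ((x[:, None] - locs[None, :]) / scale) ** 2
    logits = log_weights[None, :] + log_lik
    # Gumbel-max trick: argmax over (logits + Gumbel noise) is an exact
    # sample from Categorical(softmax(logits)), drawn for all rows together.
    gumbel = rng.gumbel(size=logits.shape)
    return np.argmax(logits + gumbel, axis=1)

rng = np.random.default_rng(0)
locs = np.array([-5.0, 5.0])
log_weights = np.log(np.array([0.5, 0.5]))
x = np.array([-5.1, -4.8, 5.2, 4.9, -5.3])

z = plate_gibbs_update(rng, x, locs, log_weights)  # one assignment per x_i
```

The point of the design is that a single vectorized evaluation of the (N, K) table replaces N separate single-site updates, each of which would otherwise re-evaluate the whole model.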