pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Add a utility to get posterior samples of discrete latent variables #770

Closed: fehiepsi closed this issue 3 years ago

fehiepsi commented 3 years ago

Per @martinjankowiak's suggestion: after we marginalize out the discrete latent variables to run HMC and obtain the posterior over the continuous variables, it would be nice to have a utility that returns posterior samples of the discrete latent variables. That is, we have a model with likelihood p(data | discrete, continuous). We marginalize out the discrete variables and run MCMC to get p(continuous | data). This utility would then draw samples from p(discrete | data, continuous).
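The conditional in the last sentence follows from Bayes' rule. For a finite discrete latent z, p(z = k | data, continuous) ∝ p(data | z = k, continuous) p(z = k | continuous). So once MCMC has produced draws of the continuous latents, you can sample each discrete site by enumerating its values and normalizing. The following is a minimal sketch of that idea. It assumes a hypothetical 1-D Gaussian mixture with known unit scale, with one assignment per data point (independent given the continuous latents), and it uses plain NumPy rather than JAX. The helper name sample_discrete_posterior and the toy "posterior draws" are illustrative assumptions and are not part of NumPyro's API.

```python
import numpy as np


def log_normal_pdf(x, loc, scale):
    # Log density of Normal(loc, scale) evaluated at x (broadcasts).
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)


def sample_discrete_posterior(rng, data, weights, locs, scale=1.0):
    """Hypothetical helper: for each posterior draw of the continuous latents
    (weights, locs), sample z_n ~ p(z_n | x_n, weights, locs).

    data: (N,), weights: (S, K), locs: (S, K).
    Returns (z, probs) with z of shape (S, N) and probs of shape (S, N, K).
    """
    # Unnormalized log p(z_n = k | weights) + log p(x_n | z_n = k, locs): (S, N, K).
    logits = np.log(weights)[:, None, :] + log_normal_pdf(
        data[None, :, None], locs[:, None, :], scale
    )
    # Bayes' rule: normalize over the enumerated values k (stable softmax).
    logits = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(logits)
    probs /= probs.sum(axis=-1, keepdims=True)
    # Inverse-CDF sampling with one uniform per (draw, data point).
    u = rng.random(probs.shape[:-1] + (1,))
    z = (probs.cumsum(axis=-1) < u).sum(axis=-1)
    # Guard against floating-point round-off in the last cumulative sum.
    z = np.minimum(z, probs.shape[-1] - 1)
    return z, probs


# Toy usage: pretend these are 3 MCMC draws of the continuous latents.
rng = np.random.default_rng(0)
data = np.array([-5.0, -4.8, 5.1, 4.9])
weights = np.full((3, 2), 0.5)
locs = np.array([[-5.0, 5.0], [-4.9, 5.1], [-5.1, 4.8]])
z, probs = sample_discrete_posterior(rng, data, weights, locs)
```

With components this well separated, every row of z is [0, 0, 1, 1] with overwhelming probability. Pairing each row with its continuous draw gives joint samples of (discrete, continuous) from the posterior.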


fritzo commented 3 years ago

This looks like another edifying application of funsors 🙂. You may be able to simply call .sample() on the final funsor in the enumeration, then unpack the resulting Delta funsors.

fehiepsi commented 3 years ago

This is achieved through infer_discrete (in numpyro.contrib.funsor).