pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Bayesian imputation tutorial with discrete covariates #726

Closed vanAmsterdam closed 4 years ago

vanAmsterdam commented 4 years ago

Since NumPyro supports enumerating discrete latent variables, imputing missing values for discrete covariates should be possible (which would make NumPyro suitable for many more applied projects!).

Because array shapes change under parallel enumeration, it is not obvious how to adapt the continuous imputation example to discrete covariates. An example would be helpful.
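To make the shape concern concrete, here is a minimal NumPy sketch (not NumPyro itself) of what changes under parallel enumeration. The sizes `N` and `K`, the array `beta`, and the exact `(K, 1)` layout are assumptions for illustration; the point is only that an enumerated site carries all candidate values in an extra leading dimension that broadcasts against the data dimension.

```python
import numpy as np

N, K = 5, 3  # hypothetical: 5 rows, a covariate with 3 categories

# Without enumeration, a sampled covariate holds one value per row.
x_sampled = np.array([0, 2, 1, 1, 0])            # shape (N,)

# Under parallel enumeration, the site instead holds every candidate value
# in a new leading dimension (assumed layout: (K, 1)).
x_enum = np.arange(K).reshape(K, 1)              # shape (K, 1)

beta = np.array([-1.0, 0.0, 1.0])                # hypothetical per-category effect
mu_sampled = beta[x_sampled]                     # shape (N,)
mu_enum = np.broadcast_to(beta[x_enum], (K, N))  # shape (K, N)

print(mu_sampled.shape, mu_enum.shape)
```

Any downstream code that indexes or reshapes assuming a trailing shape of `(N,)` only has to tolerate the extra leading dimension.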

fehiepsi commented 4 years ago

I think it would be really nice to have a separate tutorial for this. For anyone who wants to contribute one, here is an approach that is quite similar to the one in the Bayesian imputation tutorial.

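import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
# sample x without adding its log-probability to the joint density
x = numpyro.sample('x', dist.Categorical(probs).mask(False))
log_prob = dist.Categorical(probs).log_prob(x)
# rule out values that disagree with an observed x_obs (NaN marks a missing entry)
log_prob = jnp.where(~jnp.isnan(x_obs) & (x != x_obs), -jnp.inf, log_prob)
numpyro.factor('x_obs', log_prob)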
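As a sanity check on why the masking works, the arithmetic can be reproduced in plain NumPy, outside of NumPyro. This is only a sketch: the values of `probs` and `x_obs` are made up, and the explicit sum over the enumeration dimension stands in for the marginalization that enumeration performs.

```python
import numpy as np

probs = np.array([0.2, 0.5, 0.3])              # hypothetical category probabilities
x_obs = np.array([1.0, np.nan, 2.0, np.nan])   # NaN marks a missing value
K, N = probs.shape[0], x_obs.shape[0]

# Enumerated candidate values, one row per category.
x = np.arange(K).reshape(K, 1)                 # shape (K, 1)
log_prob = np.broadcast_to(np.log(probs)[x], (K, N))

# Where x_obs is observed, rule out every candidate that disagrees with it.
observed = ~np.isnan(x_obs)
log_prob = np.where(observed & (x != x_obs), -np.inf, log_prob)

# Sum out the enumeration dimension (log-sum-exp over axis 0).
marginal = np.log(np.exp(log_prob).sum(axis=0))
print(marginal)
```

Observed rows end up contributing `log p(x_obs)`, exactly as if they had been passed as observations, while missing rows contribute zero on their own, so each candidate value is weighted only by its prior probability and whatever likelihood terms depend on it.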

If you are interested in working on this, please ping me with any questions you have.

vanAmsterdam commented 4 years ago

I've created a pull request with an initial version (#730). In addition to showcasing this method for automatically enumerating missing covariates, it discusses several forms of missing data and how to handle them.

fehiepsi commented 4 years ago

Thanks for addressing this issue, @vanAmsterdam!