pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

[ENH, DOCS] Add tutorial on how to deal with discrete latent variables #1264

Closed corneliusroemer closed 2 years ago

corneliusroemer commented 2 years ago

Numpyro is great! But it took me a while to figure out how to work with discrete latent variables. Unfortunately, there's not much in the documentation about it yet.

It would probably be good to condense the various discussions in issues, gists and forum posts into a doc page.

These are the resources that helped me - if anyone stumbles upon this issue, here you go ;)

- Arviz workaround: #1121
- How to get discrete samples: https://github.com/pyro-ppl/numpyro/issues/770
- Tutorial: https://gist.github.com/peterroelants/277d9d47a76a55c23a433d37bbbd6dd7/732b2af0881f96a0175a5f79ca556f3a48996e7a
- Example in docs (not dedicated to discrete latent variables): https://num.pyro.ai/en/stable/examples/annotation.html

fehiepsi commented 2 years ago

> not dedicated to discrete latent variable

I guess your request is to have a tutorial that works with arviz? Currently, we have five examples/tutorials for models with discrete latent variables that cover various scenarios: mixture models, hidden Markov models, time series, imputation, and non-enumerated inference (with DiscreteHMCGibbs in the nested sampling example). I think we can add some code to the annotation example to illustrate that use case (it might be better to have such an example/tutorial in arviz, but not many frameworks support inference with discrete latent variables through marginalization). Do you want to contribute? :)

aflaxman commented 2 years ago

100% agree that numpyro is great! I've been using pymc2 long after it was deprecated, and I'm thinking about switching to numpyro.

If you do write documentation on this, can you help me understand why plates are needed? That is, why doesn't a version like this work?

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def gmm_model_no_plate(data, k):
    # Mixture weights over the k components
    selection_prob = numpyro.sample('selection_prob', dist.Dirichlet(concentration=jnp.ones(k)))

    # Component locations and scales
    mu = numpyro.sample('mu', dist.Normal(loc=jnp.zeros(k), scale=10. * jnp.ones(k)))
    sigma = numpyro.sample('scale', dist.HalfCauchy(scale=10. * jnp.ones(k)))

    # Discrete latent: which component generated the data
    cluster_idx = numpyro.sample('cluster_idx', dist.Categorical(selection_prob))
    numpyro.sample('x', dist.Normal(loc=mu[cluster_idx], scale=sigma[cluster_idx]), obs=data)
```
fehiepsi commented 2 years ago

Hi @aflaxman, that version will work with algorithms like DiscreteHMCGibbs or MixedHMC. But for algorithms that require enumeration, at each site we need to specify which dimension to marginalize over. For example, for each value of cluster_idx, we have a corresponding likelihood of x. Under enumeration, x gains an additional cluster_idx dimension, and we need to specify which dimension of x we want to marginalize over (i.e., sum or integrate out).
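To make concrete what "marginalizing over the cluster_idx dimension" means, here is a small hand-rolled JAX sketch (not the numpyro internals) that enumerates the discrete latent of the GMM above and sums it out with `logsumexp`:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

def gmm_marginal_logpdf(x, selection_prob, mu, sigma):
    """Log density of x with cluster_idx summed out (what enumeration computes)."""
    # Enumerate cluster_idx: broadcast each x against all k components,
    # giving an array of shape (n, k) of per-component joint log densities.
    comp = jnp.log(selection_prob) + norm.logpdf(x[:, None], loc=mu, scale=sigma)
    # Marginalize the discrete latent: logsumexp over the enumeration dimension.
    return logsumexp(comp, axis=-1)  # shape (n,)
```

The extra axis introduced by `x[:, None]` plays the role of the "additional cluster_idx dimension" described above, and `axis=-1` is the explicit choice of which dimension to marginalize.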

plate is useful for many other use cases, to name a few:

aflaxman commented 2 years ago

Super helpful reply, much appreciated!