pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Implement `RelaxedOneHotCategoricalStraightThrough` #559

Open rtbs-dev opened 4 years ago

rtbs-dev commented 4 years ago

Following the discussion in #548, and while we wait for support for discrete latent variables, it would be nice to have the Gumbel-Softmax categorical approximation that Pyro provides as `RelaxedOneHotCategoricalStraightThrough`. I didn't realize this was the name Pyro gives to Gumbel-Softmax, but hopefully replicating it will be straightforward.

NumPyro (via JAX) seems uniquely suited to problems involving large discrete structures (e.g. networks), so the ability to recover latent discrete variables (or their approximations) would be fantastic!
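For reference, the core of the straight-through Gumbel-Softmax estimator is small. Below is a minimal JAX sketch that depends only on `jax` itself. `gumbel_softmax_st` is a hypothetical helper name, not an existing NumPyro or Pyro API. The forward pass returns a hard one-hot sample (the argmax of the Gumbel-perturbed logits), while gradients flow through the relaxed softmax sample via `jax.lax.stop_gradient`.

```python
import jax
import jax.numpy as jnp


def gumbel_softmax_st(key, logits, temperature):
    """Hypothetical helper: straight-through Gumbel-Softmax sample.

    Forward value: a one-hot vector over the last axis of `logits`.
    Gradient: that of the relaxed sample softmax((logits + g) / temperature).
    """
    # Draw standard Gumbel(0, 1) noise with the same shape as the logits.
    gumbel = jax.random.gumbel(key, logits.shape)
    # Relaxed (soft) sample on the probability simplex.
    soft = jax.nn.softmax((logits + gumbel) / temperature, axis=-1)
    # Hard one-hot sample: the argmax of the perturbed logits.
    hard = jax.nn.one_hot(jnp.argmax(soft, axis=-1), logits.shape[-1], dtype=soft.dtype)
    # Straight-through trick: the value equals `hard`, but stop_gradient hides
    # (hard - soft) from autodiff, so gradients are those of `soft`.
    return soft + jax.lax.stop_gradient(hard - soft)


# Example usage: one sample from a 3-category distribution.
key = jax.random.PRNGKey(0)
logits = jnp.log(jnp.array([0.2, 0.3, 0.5]))
sample = gumbel_softmax_st(key, logits, temperature=0.5)
```

A NumPyro distribution would presumably wrap this in its `sample` and `log_prob` methods. One design question for the implementation is whether `log_prob` should be evaluated on the hard one-hot value or on the underlying relaxed sample.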

rtbs-dev commented 4 years ago

Link to the original implementation in Pyro.

daydreamt commented 4 years ago

If no one else picks this up before then (and no hard feelings if someone does), I will give this a stab starting next weekend.

neerajprad commented 4 years ago

@daydreamt - Please go ahead, you will be assured a thorough and timely review. If you have any questions around the codebase, please let us know.