arnauqb / GumbelSoftmax.jl

Julia implementation of the Gumbel-Softmax reparametrization trick compatible with Zygote and ForwardDiff
MIT License

GumbelSoftmax


This package implements:

  1. Gumbel-Softmax
  2. Rao-Blackwellized Gumbel-Softmax

in the Julia programming language. The package supports both reverse-mode and forward-mode automatic differentiation (AD), through Zygote and ForwardDiff respectively.

1. Installation

The easiest way to install the package is directly from the Julia package registry:

using Pkg
Pkg.add("GumbelSoftmax")

2. Usage

The expected input shape is (latent_dimension, categorical_dimension, batch_dimension). For example, suppose we have 4 categorical distributions with 3 classes each and we want to draw 10 samples. In this case, latent_dimension=4, categorical_dimension=3, and batch_dimension=10.

using GumbelSoftmax, Random
logits = randn(4, 3, 10)
samples = sample_gumbel_softmax(logits=logits, tau=0.1, hard=true)
# or with Rao-Blackwellization
k = 10 # number of Monte-Carlo samples
samples = sample_rao_gumbel_softmax(logits=logits, tau=0.1, k=k)
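
To make the shape convention concrete, here is a minimal standalone sketch of plain (soft) Gumbel-Softmax sampling. It is not the package's implementation: `softmax_along` and `gumbel_softmax_sketch` are hypothetical helper names, and the code assumes the class dimension is the second one, as described above.

```julia
using Random

# Numerically stable softmax along the dimension `dims`.
function softmax_along(x; dims)
    shifted = x .- maximum(x; dims=dims)
    e = exp.(shifted)
    return e ./ sum(e; dims=dims)
end

# Hypothetical helper, NOT part of GumbelSoftmax.jl.
# `logits` has shape (latent_dimension, categorical_dimension, batch_dimension).
function gumbel_softmax_sketch(rng::AbstractRNG, logits::AbstractArray; tau::Real=1.0)
    # If E ~ Exponential(1), then -log(E) ~ Gumbel(0, 1).
    gumbel_noise = -log.(randexp(rng, size(logits)...))
    # Perturb the logits, divide by the temperature, and normalise over classes.
    return softmax_along((logits .+ gumbel_noise) ./ tau; dims=2)
end

rng = MersenneTwister(0)
logits = randn(rng, 4, 3, 10)
soft_samples = gumbel_softmax_sketch(rng, logits; tau=0.1)
```

Each slice `soft_samples[i, :, b]` is a probability vector over the 3 classes. As `tau` decreases, these vectors concentrate on a single class and approach one-hot samples.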

3. Example: Discrete Variational Autoencoder (VAE)

We include an example that uses the Gumbel-Softmax trick to implement a discrete VAE. The example is in examples/vae.jl and can be run with

julia examples/vae.jl

Here are some results:

VAE loss:

[Figure: loss]

VAE reconstructions and generated samples: