vithursant / VAE-Gumbel-Softmax

An implementation of a Variational Autoencoder using the Gumbel-Softmax reparametrization trick (ICLR 2017) in TensorFlow, tested on r1.5 for both CPU and GPU.
Apache License 2.0
deeplearning gumbel-softmax mnist tensorflow vae variational-autoencoder

VAE with Gumbel-Softmax

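TensorFlow implementation of a Variational Autoencoder with the Gumbel-Softmax distribution. Refer to the following paper: Categorical Reparameterization with Gumbel-Softmax by Eric Jang, Shixiang Gu and Ben Poole (ICLR 2017).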

Also included is a Jupyter notebook that shows how the Gumbel-Max trick for sampling discrete variables relates to Concrete distributions.
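As a rough, framework-free illustration of that relationship, the sketch below uses only the Python standard library. It is not TensorFlow and not code taken from this repository, and the helper names `sample_gumbel`, `gumbel_max` and `gumbel_softmax` are hypothetical. Adding independent Gumbel(0, 1) noise to unnormalized log-probabilities and taking the argmax gives an exact categorical sample (Gumbel-Max). Replacing that argmax with a softmax at temperature τ gives the continuous Gumbel-Softmax (Concrete) relaxation, whose argmax coincides with the Gumbel-Max sample when both use the same noise.

```python
import math
import random


def sample_gumbel(rng):
    """Draw one Gumbel(0, 1) sample by inverse transform: -log(-log(U))."""
    # random() returns U in [0, 1); clamp away from 0 so log(U) is defined.
    u = max(rng.random(), 1e-20)
    return -math.log(-math.log(u))


def gumbel_max(logits, rng):
    """Gumbel-Max trick: argmax_i(logits[i] + g_i) is an exact sample from softmax(logits)."""
    perturbed = [x + sample_gumbel(rng) for x in logits]
    return max(range(len(perturbed)), key=perturbed.__getitem__)


def gumbel_softmax(logits, temperature, rng):
    """Gumbel-Softmax (Concrete) sample: softmax((logits + g) / temperature)."""
    perturbed = [(x + sample_gumbel(rng)) / temperature for x in logits]
    shift = max(perturbed)  # subtract the max for numerical stability
    exps = [math.exp(p - shift) for p in perturbed]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    probs = [0.1, 0.3, 0.6]
    logits = [math.log(p) for p in probs]

    # Empirical Gumbel-Max frequencies should approach `probs`.
    rng = random.Random(0)
    n = 10000
    counts = [0] * len(logits)
    for _ in range(n):
        counts[gumbel_max(logits, rng)] += 1
    print([c / n for c in counts])

    # A relaxed sample lies on the probability simplex; with the same noise
    # (same seed), its argmax equals the Gumbel-Max sample.
    soft = gumbel_softmax(logits, 0.5, random.Random(1))
    hard = gumbel_max(logits, random.Random(1))
    print(soft, max(range(len(soft)), key=soft.__getitem__) == hard)
```

Lower temperatures push the relaxed sample toward a one-hot vector and higher temperatures make it smoother, which is why the temperature is typically annealed during training.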

The program requires the following dependencies (easy to install using pip, Anaconda, or Docker):


Anaconda: CPU Installation

To install VAE-Gumbel-Softmax in a TensorFlow 1.5 CPU, Python 2.7 environment:

conda env create -f tf_py26_cpu_env.yml

To activate Anaconda environment:

source activate tf-py26-cpu-env

Anaconda: GPU Installation

To install VAE-Gumbel-Softmax in a TensorFlow 1.5 GPU, Python 3.5 environment:

conda env create -f tf_py35_gpu_env.yml

To activate Anaconda environment:

source activate tf-py35-gpu-env

Anaconda: Train

Train the VAE-Gumbel-Softmax model on the local machine using the MNIST dataset:



Train the VAE-Gumbel-Softmax model on the MNIST dataset using Docker:

docker build -t vae-gs .
docker run vae-gs

Note: The current Dockerfile is for TensorFlow 1.5 CPU training only.



Hyperparameters

Batch Size:                         100
Number of Iterations:               50000
Learning Rate:                      0.001
Initial Temperature:                1.0
Minimum Temperature:                0.5
Anneal Rate:                        0.00003
Straight-Through Gumbel-Softmax:    False
KL-divergence:                      Relaxed
Learnable Temperature:              False
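The temperature settings above (initial 1.0, minimum 0.5, anneal rate 0.00003) fit an exponential schedule of the form τ = max(τ_min, τ₀·exp(−r·t)), similar to the one described in the Gumbel-Softmax paper. Assuming this repository follows such a schedule, and assuming τ is refreshed only every `update_every` steps (1000 is a hypothetical value, not read from the code; `annealed_temperature` is likewise a hypothetical name), the temperature over training can be sketched as:

```python
import math


def annealed_temperature(step, initial_temp=1.0, min_temp=0.5,
                         anneal_rate=0.00003, update_every=1000):
    """Return tau = max(min_temp, initial_temp * exp(-anneal_rate * t)).

    Assumption: tau is only recomputed every `update_every` steps, so t is the
    most recent multiple of `update_every` at or below `step`.
    """
    t = (step // update_every) * update_every
    return max(min_temp, initial_temp * math.exp(-anneal_rate * t))


if __name__ == "__main__":
    for step in (0, 10000, 23000, 24000, 49999):
        print(step, round(annealed_temperature(step), 4))
    # 0 1.0
    # 10000 0.7408
    # 23000 0.5016
    # 24000 0.5
    # 49999 0.5
```

Under this assumed schedule the 0.5 floor is reached after roughly ln(2)/0.00003 ≈ 23,105 steps, so the second half of the 50,000 iterations would run at the minimum temperature.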


Figure: ground truth MNIST digits vs. their reconstructions

Citing VAE-Gumbel-Softmax

If you use VAE-Gumbel-Softmax in a scientific publication, I would appreciate a reference to the source code.

Biblatex entry:

@misc{VAEGumbelSoftmax,
  author = {Thangarasa, Vithursan},
  title = {VAE-Gumbel-Softmax},
  year = {2017},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{}}
}