This tutorial implements Learning Structured Output Representation using Deep Conditional Generative Models paper, which introduced Conditional Variational Auto-encoders in 2015, using Pyro PPL.
Supervised deep learning has been successfully applied for many recognition problems in machine learning and computer vision. Although it can approximate a complex many-to-one function very well when large number of training data is provided, the lack of probabilistic inference of the current supervised deep learning methods makes it difficult to model a complex structured output representations. In this work, Kihyuk Sohn, Honglak Lee and Xinchen Yan develop a scalable deep conditional generative model for structured output variables using Gaussian latent variables. The model is trained efficiently in the framework of stochastic gradient variational Bayes, and allows a fast prediction using stochastic feed-forward inference. They called the model Conditional Variational Auto-encoder (CVAE).
The CVAE is a conditional directed graphical model whose input observations modulate the prior on Gaussian latent variables that generate the outputs. It is trained to maximize the conditional marginal log-likelihood. The authors formulate the variational learning objective of the CVAE in the framework of stochastic gradient variational Bayes (SGVB). In experiments, they demonstrate the effectiveness of the CVAE in comparison to the deterministic neural network counterparts in generating diverse but realistic output predictions using stochastic inference. Here, we will implement their proof of concept: an artificial experimental setting for structured output prediction using MNIST database.
Let's divide each digit image into four quadrants, and take one, two, or three quadrant(s) as an input and the remaining quadrants as an output to be predicted. The image below shows the case where one quadrant is the input:
Our objective is to learn a model that can perform probabilistic inference and make diverse predictions from a single input. This is because we are not simply modeling a many-to-one function as in classification tasks, but we may need to model a mapping from single input to many possible outputs. One of the limitations of deterministic neural networks is that they generate only a single prediction. In the example above, the input shows a small part of a digit that might be a three or a five.
For qualitative analysis, we visualize the generated output samples in the next figure. As we can see, the baseline NNs can only make a single deterministic prediction, and as a result the output looks blurry and doesn’t look realistic in many cases. In contrast, the samples generated by the CVAE models are more realistic and diverse in shape; sometimes they can even change their identity (digit labels), such as from 3 to 5 or from 4 to 9, and vice versa.
We also provide a quantitative evidence by estimating the marginal conditional log-likelihoods (CLLs) in next table.
1 quadrant | 2 quadrants | 3 quadrants | |
---|---|---|---|
NN (baseline) | 100.4 | 61.9 | 25.4 |
CVAE (Monte Carlo) | 71.8 | 51.0 | 24.2 |
Performance gap | 28.6 | 10.9 | 1.2 |
We achieved similar results to the ones achieved by the authors in the paper. We trained only for 50 epochs with early stopping patience of 3 epochs; to improve the results, we could leave the algorithm training for longer. Nevertheless, we can observe the same effect shown in the paper: the estimated CLLs of the CVAE significantly outperforms the baseline NN.
See the full code on Github.
There are some issue reports when trying to run the code with Pyro versions different than the one in requirements.txt
.
So, to make sure the code works, the recommended way is to create a clean virtual environment (conda or virtualenv), and running pip install -r requirements.txt
in this new environment.
[1] Learning Structured Output Representation using Deep Conditional Generative Models
,
Kihyuk Sohn, Xinchen Yan, Honglak Lee