masa-su / pixyz

A library for developing deep generative models in a more concise, intuitive and extendable way
https://pixyz.io
MIT License

Implementation of the MVAE model #80

Open miguelsvasco opened 4 years ago

miguelsvasco commented 4 years ago

First of all, thank you for the code! Just two slight remarks:

[image: the implemented regularizer, which includes extra KL terms (kl_x, kl_y) between the PoE encoder and the unimodal encoders]

However, the original formulation of the MVAE model (in the paper Multimodal Generative Models for Scalable Weakly-Supervised Learning) does not include such terms, only a KL divergence term between the distribution of the PoE encoder and the prior:

[image: the original MVAE ELBO, with a single KL term between the PoE encoder and the prior]
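For reference, the PoE encoder mentioned above combines the unimodal Gaussian experts (and a standard-normal prior expert) into one Gaussian by adding precisions and precision-weighting the means. A minimal stdlib-only sketch of that fusion (illustrative names, not the pixyz API):

```python
import math

def poe(mus, logvars):
    """Combine diagonal-Gaussian experts q(z|x), q(z|y), ... with a N(0, 1) prior expert.

    A product of Gaussians is Gaussian: precisions add, and the mean is the
    precision-weighted average of the expert means.
    """
    dim = len(mus[0])
    # Start from the prior expert N(0, 1): precision 1, mean 0.
    precision_sum = [1.0] * dim
    weighted_mean = [0.0] * dim
    for mu, logvar in zip(mus, logvars):
        for d in range(dim):
            prec = 1.0 / math.exp(logvar[d])
            precision_sum[d] += prec
            weighted_mean[d] += prec * mu[d]
    var = [1.0 / p for p in precision_sum]
    mu = [w * v for w, v in zip(weighted_mean, var)]
    return mu, var

# Two unimodal experts agreeing near z = 1.0 pull the joint posterior toward it,
# while the prior expert shrinks it back toward 0.
mu, var = poe(mus=[[1.0], [1.0]], logvars=[[0.0], [0.0]])
```

With two unit-variance experts at mean 1 plus the prior, the fused posterior has precision 3, so its mean is 2/3 and its variance 1/3.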

When I remove the kl_x and kl_y terms from the regularizer and train, the model seems unable to perform cross-modality inference:

[image: cross-modality inference samples after training without the kl_x and kl_y terms]

masa-su commented 4 years ago

@miguelsvasco Hi, thank you for your detailed comments and I'm sorry for my late reply.

However, the original formulation of the MVAE model (in the paper Multimodal Generative Models for Scalable Weakly-Supervised Learning) does not include such terms, only a KL divergence term between the distribution of the PoE encoder and the prior

Yes, this loss function comes not from the MVAE but from the JMVAE (originally proposed in this paper as JMVAE-kl). Although the PoE encoder is not used in the original JMVAE paper, we wanted to see whether this PoE encoder works well with the JMVAE loss. Anyway, I'm sorry for the confusion.

When I remove the kl_x and kl_y terms from the regularizer and train, the model seems unable to perform cross-modality inference:

This might be because the "unimodal" inferences of the PoE encoder, q(z|x) and q(z|y), are not trained. Without this, z inferred from a unimodal input (especially a label or attribute) might collapse (a similar issue is referred to in our preprint paper as the "missing modality difficulty"). In the JMVAE, these are trained by pushing them close to the "bimodal" inference q(z|x,y), which corresponds to the additional KL terms you pointed out.
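Those additional terms are KL divergences between Gaussians, which have a closed form. A stdlib-only sketch of one such term (the numbers and names below are purely illustrative):

```python
import math

def kl_gauss(mu_q, var_q, mu_p, var_p):
    """Closed-form KL( N(mu_q, var_q) || N(mu_p, var_p) ) for 1-D Gaussians."""
    return 0.5 * (math.log(var_p / var_q)
                  + (var_q + (mu_q - mu_p) ** 2) / var_p
                  - 1.0)

# Bimodal posterior q(z|x,y) vs. a unimodal encoder q(z|x) that has drifted.
mu_xy, var_xy = 1.0, 0.25
mu_x, var_x = 0.0, 1.0

# The JMVAE-kl regularizer adds terms like this one (and the analogue for y),
# pulling each unimodal encoder toward the bimodal posterior.
kl_x = kl_gauss(mu_xy, var_xy, mu_x, var_x)
```

When the unimodal encoder matches the bimodal posterior exactly, the term vanishes, so minimizing it is what trains q(z|x) and q(z|y) to stay usable on their own.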

Would it be possible to implement this with the Pixyz framework?

Yes, but you should use the Model class instead of the VAE class, because the loss function becomes more complex. The implementation of the original MVAE model in Pixyz is as follows:

[screenshot (2019-08-06): Pixyz implementation of the original MVAE model]
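The screenshot itself is not recoverable as text. As a rough illustration of what the MVAE objective combines, the paper trains with sub-sampled ELBOs over modality subsets (here {x, y}, {x}, {y}), so the unimodal branches of the PoE encoder are also trained; the helper functions below are hypothetical stand-ins, not the pixyz code from the notebook:

```python
def elbo(recon_log_probs, kl):
    """ELBO = sum of reconstruction log-likelihoods minus the KL term."""
    return sum(recon_log_probs) - kl

def mvae_loss(elbo_xy, elbo_x, elbo_y):
    """Loss to minimize: negative sum of the joint and unimodal ELBOs."""
    return -(elbo_xy + elbo_x + elbo_y)

# Toy per-subset values standing in for outputs of the networks.
loss = mvae_loss(
    elbo_xy=elbo([-90.0, -2.0], kl=5.0),  # both modalities observed
    elbo_x=elbo([-95.0], kl=4.0),         # image only
    elbo_y=elbo([-2.5], kl=1.0),          # label only
)
```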

Given your comments, I renamed the previous notebook from mvae_poe.ipynb to jmvae_poe.ipynb (to avoid confusion) and added a new notebook, mvae.ipynb, which contains the implementation of the original MVAE model.

Thank you!

sgalkina commented 4 years ago

@masa-su Thank you for the framework. For the MVAE implementation you provided above, how should the model be trained in the semi-supervised case? Say that for the MNIST dataset only a fraction of the labels is available. Should two Model objects be created that share the networks but have different loss functions, for 1) both the image and the label available and 2) only the image available?

masa-su commented 4 years ago

@sgalkina Thank you for your comment! I'm not sure what kind of loss functions you are going to implement for the supervised and unsupervised cases, but you can use the replace_var method of the Distribution class to share the same network across different losses, e.g., the supervised and unsupervised losses.
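The underlying idea, independent of the pixyz API, is that both losses call the same encoder object, so its parameters are shared. A plain-Python sketch of that weight-sharing pattern (all names here are hypothetical, not pixyz code):

```python
class Encoder:
    """Stand-in for a shared inference network q(z|x)."""
    def __init__(self, weight):
        self.weight = weight  # shared parameter

    def __call__(self, x):
        return self.weight * x

shared_q = Encoder(weight=0.5)

def supervised_loss(x, y):
    # Uses the shared encoder together with the label.
    z = shared_q(x)
    return (z - y) ** 2

def unsupervised_loss(x):
    # Same encoder object, so any update to its weight affects both losses.
    z = shared_q(x)
    return z ** 2

total = supervised_loss(2.0, 1.0) + unsupervised_loss(2.0)
```

In pixyz, replace_var serves a related purpose: it renames a distribution's variables so that the one network can appear under different variable names inside the supervised and unsupervised loss expressions.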

For an example of its usage, please see the implementation of the M2 model, the well-known semi-supervised VAE model.

If you have any trouble understanding how to use it, please feel free to ask!