Open miguelsvasco opened 4 years ago
@miguelsvasco Hi, thank you for your detailed comments and I'm sorry for my late reply.
However, the original formulation of the MVAE model (in the paper Multimodal Generative Models for Scalable Weakly-Supervised Learning), does not consider such terms, only a KL divergence term between the distribution of the POE encoder and the prior
Yes. This loss function comes not from MVAE but from JMVAE (originally proposed in this paper as JMVAE-kl). Though the PoE encoder is not used in the original JMVAE paper, we wanted to see whether the PoE encoder works well with the JMVAE loss. Anyway, I'm sorry for the confusion.
When I remove the kl_x and kl_y terms from the regularizer and train, the model seems unable to perform cross-modality inference:
This might be because the "unimodal" inferences of the PoE encoder, q(z|x) and q(z|y), are not trained. Without that training, the z inferred from a unimodal input (especially a label or attribute) might collapse (a similar issue is referred to in our preprint paper as the "missing modality difficulty"). In JMVAE, the unimodal inferences are trained by pulling them close to the "bimodal" inference q(z|x,y), which corresponds to the additional KL terms you pointed out.
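To make the difference between the two regularizers concrete, here is a minimal sketch in plain Python (not Pixyz). It assumes diagonal Gaussian experts, a standard normal prior included as an expert in the PoE, and that kl_x and kl_y denote KL(q(z|x,y) || q(z|x)) and KL(q(z|x,y) || q(z|y)). The helper names and the numbers are made up for illustration and are not taken from the notebooks.

```python
import math


def product_of_experts(experts):
    """Fuse diagonal Gaussian experts, each given as (means, variances).

    Precisions add up, and the fused mean is the precision-weighted
    average of the expert means.
    """
    dims = len(experts[0][0])
    fused_mean, fused_var = [], []
    for d in range(dims):
        precision = sum(1.0 / var[d] for _, var in experts)
        weighted = sum(mean[d] / var[d] for mean, var in experts)
        fused_mean.append(weighted / precision)
        fused_var.append(1.0 / precision)
    return fused_mean, fused_var


def kl_diag_gaussians(q, p):
    """KL(q || p) for two diagonal Gaussians, summed over dimensions."""
    (mean_q, var_q), (mean_p, var_p) = q, p
    total = 0.0
    for mq, vq, mp, vp in zip(mean_q, var_q, mean_p, var_p):
        total += 0.5 * (math.log(vp / vq) + (vq + (mq - mp) ** 2) / vp - 1.0)
    return total


# Hypothetical encoder outputs for one data point (2-dimensional z).
prior = ([0.0, 0.0], [1.0, 1.0])
expert_x = ([1.0, -1.0], [0.5, 0.5])   # e.g. image encoder
expert_y = ([0.5, 0.0], [2.0, 2.0])    # e.g. label encoder

q_xy = product_of_experts([prior, expert_x, expert_y])  # q(z|x,y)
q_x = product_of_experts([prior, expert_x])             # q(z|x)
q_y = product_of_experts([prior, expert_y])             # q(z|y)

kl_prior = kl_diag_gaussians(q_xy, prior)
kl_x = kl_diag_gaussians(q_xy, q_x)
kl_y = kl_diag_gaussians(q_xy, q_y)

mvae_regularizer = kl_prior                  # KL to the prior only
jmvae_regularizer = kl_prior + kl_x + kl_y   # plus the unimodal KL terms
```

The extra terms are the only place where q(z|x) and q(z|y) receive a training signal in this sketch, which is consistent with the explanation above of why removing them hurts cross-modality inference.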
Would that be possible to implement with the Pixyz framework?
Yes, but you should use the Model class instead of the VAE class because the loss function becomes more complex.
The implementation of the original MVAE model with Pixyz is as follows.
Given your comments, I renamed the previous notebook from mvae_poe.ipynb to jmvae_poe.ipynb (to avoid confusion) and added a new notebook, mvae.ipynb, which includes the implementation of the original MVAE model.
Thank you!
@masa-su Thank you for the framework. For the MVAE implementation you provided above, how should the model be trained in the semi-supervised case? Let's say that for the MNIST dataset only a share of the labels is available. Should two Model objects be created that share the networks but have different loss functions, one for 1) the image and the label available and one for 2) only the label available?
@sgalkina Thank you for your comment!
I don't know what kind of loss functions you are going to implement for the supervised and unsupervised cases, but you can use the replace_var method of the Distribution class to share the same network across different losses, e.g., supervised and unsupervised losses.
For an example of the usage, please see the implementation of the M2 model, which is the well-known semi-supervised VAE model.
If you have any trouble understanding how to use it, please feel free to ask!
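As a side note on the semi-supervised question, one possible approach (a sketch only, not the Pixyz API and not necessarily what the notebooks do) is to let the PoE encoder skip the expert of a missing modality, so the same networks serve both paired and image-only examples. The function name fuse_available and all numbers below are hypothetical.

```python
def fuse_available(experts, prior):
    """PoE over the prior plus whichever modality experts are present.

    Each expert is (means, variances) of a diagonal Gaussian, or None
    when that modality is missing for the current example.
    """
    active = [prior] + [e for e in experts if e is not None]
    dims = len(prior[0])
    mean, var = [], []
    for d in range(dims):
        precision = sum(1.0 / v[d] for _, v in active)
        mean.append(sum(m[d] / v[d] for m, v in active) / precision)
        var.append(1.0 / precision)
    return mean, var


prior = ([0.0], [1.0])
image_expert = ([2.0], [0.5])   # hypothetical image encoder output
label_expert = ([1.0], [1.0])   # hypothetical label encoder output

# Labelled example: both experts contribute to q(z|x,y).
paired = fuse_available([image_expert, label_expert], prior)

# Unlabelled example: the label expert is dropped, giving q(z|x).
image_only = fuse_available([image_expert, None], prior)
```

With this layout the supervised and unsupervised losses can differ while the encoders stay shared, which is the same goal that replace_var serves in Pixyz.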
First of all, thank you for the code! Just two small remarks:
However, the original formulation of the MVAE model (in the paper Multimodal Generative Models for Scalable Weakly-Supervised Learning), does not consider such terms, only a KL divergence term between the distribution of the POE encoder and the prior:
When I remove the kl_x and kl_y terms from the regularizer and train, the model seems unable to perform cross-modality inference: