This repository contains the code for the framework in Variational Mixture-of-Experts Autoencoders for Multi-Modal Deep Generative Models (see paper).
## Requirements

The packages we used and the versions we tested the model on (see also `requirements.txt`, installable with `pip install -r requirements.txt`):

```
python == 3.6.8
gensim == 3.8.1
matplotlib == 3.1.1
nltk == 3.4.5
numpy == 1.16.4
pandas == 0.25.3
scipy == 1.3.2
seaborn == 0.9.0
scikit-image == 0.15.0
torch == 1.3.1
torchnet == 0.0.4
torchvision == 0.4.2
umap-learn == 0.1.1
```
## Datasets

### MNIST-SVHN

We construct a dataset of pairs of MNIST and SVHN such that each pair depicts the same digit class. Each instance of a digit class in either dataset is randomly paired with 20 instances of the same digit class from the other dataset.

Usage: To prepare this dataset, run `bin/make-mnist-svhn-idx.py` -- this should automatically handle the download and pairing. A rough sketch of the pairing step follows.
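For intuition, here is a minimal sketch of the class-matched pairing, assuming the labels and indices of both datasets are PyTorch tensors (the function name and signature are illustrative, not necessarily the script's actual code):

```python
import torch

def rand_match_on_idx(l1, idx1, l2, idx2, dm=20):
    # For each digit class, pair every instance of one dataset with `dm`
    # randomly chosen instances of the same class from the other dataset.
    _idx1, _idx2 = [], []
    for l in l1.unique():
        l_idx1, l_idx2 = idx1[l1 == l], idx2[l2 == l]
        n = min(l_idx1.size(0), l_idx2.size(0))
        l_idx1, l_idx2 = l_idx1[:n], l_idx2[:n]
        for _ in range(dm):  # dm independent shufflings give dm pairings per instance
            _idx1.append(l_idx1[torch.randperm(n)])
            _idx2.append(l_idx2[torch.randperm(n)])
    return torch.cat(_idx1), torch.cat(_idx2)  # aligned indices into MNIST and SVHN
```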
### CUB Image-Caption

We use the Caltech-UCSD Birds (CUB) dataset, with the bird images and their captions serving as the two modalities.

Usage: We offer a cleaned-up version of the CUB dataset. Download the dataset here. First, create a `data` folder under the project directory; then unzip the downloaded content into `data`. After finishing these steps, the structure of the `data/cub` folder should look like:
```
data/cub
│───text_testclasses.txt
│───text_trainvalclasses.txt
│───train
│   │───002.Laysan_Albatross
│   │   └───...jpg
│   │───003.Sooty_Albatross
│   │   └───...jpg
│   │───...
│   └───200.Common_Yellowthroat
│       └───...jpg
└───test
    │───001.Black_footed_Albatross
    │   └───...jpg
    │───004.Groove_billed_Ani
    │   └───...jpg
    │───...
    └───197.Marsh_Wren
        └───...jpg
```
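To sanity-check the layout, the image half can be read with torchvision's generic `ImageFolder` loader (a throwaway sketch; the repository ships its own dataloaders):

```python
from torchvision import datasets, transforms

# Load the training images purely to verify the directory structure above.
tf = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor()])
train_imgs = datasets.ImageFolder('data/cub/train', transform=tf)
print(len(train_imgs), train_imgs.classes[:2])  # e.g. ['002.Laysan_Albatross', '003.Sooty_Albatross']
```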
## Pretrained models

Pretrained models are also available if you want to play around with them. Download from the following links:
## Usage

Make sure the requirements are satisfied in your environment and the relevant datasets are downloaded. `cd` into `src`, then, for MNIST-SVHN experiments, run

```
python main.py --model mnist_svhn
```

For CUB Image-Caption with image feature search (see Figure 7 in our paper), run

```
python main.py --model cubISft
```

For CUB Image-Caption with raw image generation, run

```
python main.py --model cubIS
```
You can also play with the hyperparameters using command-line arguments. Some of the more interesting ones are listed below:

- `--obj`: Objective function; offers 3 choices: importance-sampled ELBO (`elbo`), IWAE (`iwae`) and DReG (`dreg`, used in the paper). Including the `--looser` flag when using IWAE or DReG removes the unbalanced weighting of modalities, which we find to perform better empirically;
- `--K`: Number of particles; controls the number of particles K in the IWAE/DReG estimator, as specified in the equation following this list;
- `--learn-prior`: Prior variance learning; controls whether to enable prior variance learning. Results in our paper are produced with this enabled. Excluding this argument from the command disables this option;
- `--llik_scaling`: Likelihood scaling; specifies the likelihood scaling of one of the two modalities, so that the likelihoods of the two modalities contribute comparably to the lower bound. The default values are:
- `--latent-dim`: Latent dimension.
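For reference, the standard K-particle IWAE estimator has the form below; the exact multimodal variants used in this work are given in the paper:

$$\mathcal{L}^{K}_{\text{IWAE}}(x) = \mathbb{E}_{z^{1},\dots,z^{K} \sim q_{\Phi}(z \mid x)}\left[\log \frac{1}{K}\sum_{k=1}^{K}\frac{p_{\Theta}(x, z^{k})}{q_{\Phi}(z^{k} \mid x)}\right]$$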
You can also load from pre-trained models by specifying the path to the model folder, for example `python main.py --model mnist_svhn --pre-trained path/to/model/folder/`. The flags we used for these pretrained models are:
```
--model mnist_svhn --obj dreg --K 30 --learn-prior --looser --epochs 30 --batch-size 128 --latent-dim 20
--model cubISft --learn-prior --K 50 --obj dreg --looser --epochs 50 --batch-size 64 --latent-dim 64 --llik_scaling 0.002
--model cubIS --learn-prior --K 50 --obj dreg --looser --epochs 50 --batch-size 64 --latent-dim 64
```
## Reproducing results

We offer tools to reproduce the quantitative results in our paper in `src/report`. To run any of the provided scripts, `cd` into `src` and run one of the following:

- `python calculate_likelihoods.py --save-dir path/to/trained/model/folder/ --iwae-samples 1000` to estimate log-likelihoods;
- `python analyse_ms.py --save-dir path/to/trained/model/folder/` for the MNIST-SVHN model;
- `python analyse_cub.py --save-dir path/to/trained/model/folder/` for the CUB model.
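The `--iwae-samples` flag sets the number of importance samples used for the marginal log-likelihood estimate. A generic sketch of such an estimator, where `qz_x`, `px_z` and `pz` are hypothetical stand-ins for the posterior, likelihood and prior (not the repository's actual interfaces):

```python
import torch

def iwae_log_likelihood(x, qz_x, px_z, pz, K=1000):
    # log p(x) is estimated as logsumexp_k [log p(x, z_k) - log q(z_k | x)] - log K
    zs = qz_x.rsample(torch.Size([K]))           # K samples from the posterior
    log_w = (pz.log_prob(zs).sum(-1)
             + px_z(zs).log_prob(x).sum(-1)
             - qz_x.log_prob(zs).sum(-1))        # log importance weights
    return torch.logsumexp(log_w, dim=0) - torch.log(torch.tensor(float(K)))
```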
Note that `analyse_cub.py` additionally requires a few downloaded files to be put in place:

- `cub.all`, `cub.emb`, `cub.pc` under `data/cub/oc:3_sl:32_s:300_w:3/`;
- `emb_mean.pt`, `emb_proj.pt`, `images_mean.pt`, `im_proj.pt` under `path/to/trained/model/folder/`;

then set the `RESET` variable in `src/report/analyse_cub.py` to `False`.

## Contact

If you have any questions, feel free to create an issue or email Yuge Shi at yshi@robots.ox.ac.uk.