PyTorch implementation of disentanglement algorithms for Variational Autoencoders. This library was developed as a contribution to the Disentanglement Challenge of NeurIPS 2019.
If the library helped your research, consider citing the corresponding submission of the NeurIPS 2019 Disentanglement Challenge:
```bibtex
@article{abdiDisentanglementPytorch,
  author  = {Amir H. Abdi and Purang Abolmaesumi and Sidney Fels},
  title   = {Variational Learning with Disentanglement-PyTorch},
  journal = {arXiv preprint arXiv:1912.05184},
  year    = {2019},
}
```
The implemented algorithms and loss terms are listed below under the `--alg` and `--loss_terms` flags.
Note: everything is modular, so you can mix and match neural architectures and algorithms. Multiple loss terms can also be included in the `--loss_terms` argument, each with its respective weight; this makes it possible to combine a set of disentanglement algorithms for representation learning (see the sketch below).
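As a minimal sketch of this mix-and-match design (using only the flag names and values documented in this README; the dataset flags described further below are assumed to be configured):

```bash
# Sketch: train a plain VAE with the BetaTCVAE loss term plugged into its
# objective; dataset flags/environment variables are assumed to be set.
python main.py --alg VAE --loss_terms BetaTCVAE
```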
Install the requirements:

```bash
pip install -r requirements.txt
```

Or build the conda environment:

```bash
conda env create -f environment.yml
```
The library visualizes the reconstructed images and the traversed latent spaces and saves them as static frames as well as animated GIFs. It also extensively uses the web-based Weights & Biases toolkit for logging and visualization purposes.
```bash
python main.py [[--ARG ARG_VALUE] ...]
```

or

```bash
bash scripts/SCRIPT_NAME
```
`--alg`: The main formulation for training.
Values: `AE` (AutoEncoder), `VAE` (Variational AutoEncoder), `BetaVAE`, `CVAE` (Conditional VAE), `IFCVAE` (Information Factorization CVAE)
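For example, a hedged sketch of training with one of these formulations (the dataset flags are covered under the running instructions below; the path is hypothetical):

```bash
# Sketch: train a BetaVAE; ./data is a placeholder for your dataset root.
python main.py --alg BetaVAE --dset_dir ./data --dset_name dsprites
```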
`--loss_terms`: Extensions to the VAE algorithm are implemented as plug-ins to the original formulation. As a result, if the loss terms of two learning algorithms (e.g., A and B) are found to be compatible, they can be included in the objective function simultaneously with the flag set as `--loss_terms A B`. The `--loss_terms` flag can be used with the VAE, BetaVAE, CVAE, and IFCVAE algorithms.
Values: `FACTORVAE`, `DIPVAEI`, `DIPVAEII`, `BetaTCVAE`, `INFOVAE`
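As a sketch of the `--loss_terms A B` pattern (the per-term weight flags are omitted here; check the source for the complete argument list):

```bash
# Sketch: include the FactorVAE and DIP-VAE-I loss terms in one objective,
# assuming the two terms are compatible as described above.
python main.py --alg VAE --loss_terms FACTORVAE DIPVAEI
```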
`--evaluation_metric`: Metric(s) to use for disentanglement evaluation (see `scripts/aicrowd_challenge`).
Values: `mig`, `sap_score`, `irs`, `factor_vae_metric`, `dci`, `beta_vae_sklearn`
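For instance, a sketch requesting two of the supported metrics during training (dataset flags/environment variables are assumed to be configured):

```bash
# Sketch: train a VAE and evaluate the MIG and DCI disentanglement metrics.
python main.py --alg VAE --evaluation_metric mig dci
```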
For the complete list of arguments, please check the source.
To run the scripts:

1. Set the `--dset_dir` flag or the `$DISENTANGLEMENT_LIB_DATA` environment variable to the directory holding all the datasets (the former is given priority).
2. Set the `--dset_name` flag or the `$DATASET_NAME` environment variable to the name of the dataset (the former is given priority).
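A minimal sketch using the environment variables (the dataset directory is a hypothetical path):

```bash
# Sketch: configure the dataset via environment variables; the --dset_dir
# and --dset_name flags would take priority over these if both are given.
export DISENTANGLEMENT_LIB_DATA=~/datasets   # hypothetical path
export DATASET_NAME=dsprites
python main.py --alg VAE
```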
The supported datasets are: `celebA`, `dsprites` (and the DeepMind variants: color, noisy, and scream, introduced here), `smallnorb`, `cars3d`, `mpi3d_toy`, `mpi3d_realistic`, and `mpi3d_real`.
Please check the mpi3d datasets' repository for license agreements, and consider citing their work.
Currently, there are two dataloaders in place.
To use this code in the NeurIPS 2019 Disentanglement Challenge:

- Run `source train_environ.sh NAME_OF_DATASET_TO_TEST`.
- Set `--aicrowd_challenge=true` in your bash file.
- Use `--evaluate_metric mig sap_score irs factor_vae_metric dci` to assess the progression of disentanglement metrics during training.
- Set `run.sh` to your highest performing configuration.
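Putting these steps together, an illustrative sketch of a challenge run (`mpi3d_toy` is just an example dataset name, not a tuned configuration):

```bash
# Illustrative challenge workflow using the flags documented above.
source train_environ.sh mpi3d_toy
python main.py --alg VAE --aicrowd_challenge=true \
    --evaluate_metric mig sap_score irs factor_vae_metric dci
```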
| Method | Latent traversal visualization |
|---|---|
| VAE | *(animated GIF omitted)* |
| FactorVAE | *(animated GIF omitted)* |
| CVAE (conditioned on shape) | *(animated GIF omitted; the right-most item is traversing the condition)* |
| IFCVAE (factorized on shape) | *(animated GIF omitted; the right-most factor is enforced to encode the shape)* |
| BetaTCVAE | *(animated GIF omitted)* |
Any contributions, especially around implementing more disentanglement algorithms, are welcome. Feel free to submit bugs, feature requests, or questions as issues, or contact me directly via email at: amirabdi@ece.ubc.ca