amir-abdi / disentanglement-pytorch

Disentanglement library for PyTorch
GNU General Public License v3.0


Disentanglement-PyTorch

PyTorch implementation of disentanglement algorithms for variational autoencoders (VAEs). This library was developed as a contribution to the NeurIPS 2019 Disentanglement Challenge.

If this library helps your research, please consider citing the corresponding submission to the NeurIPS 2019 Disentanglement Challenge:

@article{abdiDisentanglementPytorch,
    author = {Amir H. Abdi and Purang Abolmaesumi and Sidney Fels},
    title = {Variational Learning with Disentanglement-PyTorch},
    year = {2019},
    journal = {arXiv preprint arXiv:1912.05184},
}

The following algorithms are implemented:

Note: Everything is modular, so you can mix and match neural architectures and algorithms. Multiple loss terms can also be passed to the --loss_terms argument, each with its own weight, which makes it possible to combine several disentanglement algorithms for representation learning.
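A minimal sketch of how weighted loss terms can be combined, assuming each term is a function of the encoder's mean and log-variance. The registry `LOSS_TERMS`, the helper `total_loss`, and the term name `"kld"` are hypothetical illustrations, not this library's API. Only the closed-form KL divergence between a diagonal Gaussian and a standard normal is the standard VAE regularizer.

```python
import math
from typing import Callable, Dict, List, Sequence, Tuple

# Signature shared by all (hypothetical) loss terms: (mu, logvar) -> scalar.
LossFn = Callable[[Sequence[float], Sequence[float]], float]


def gaussian_kl(mu: Sequence[float], logvar: Sequence[float]) -> float:
    """KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over latent dimensions."""
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0 for m, lv in zip(mu, logvar))


# Hypothetical registry: loss-term name -> function of the encoder outputs.
LOSS_TERMS: Dict[str, LossFn] = {
    "kld": gaussian_kl,
}


def total_loss(
    recon: float,
    mu: Sequence[float],
    logvar: Sequence[float],
    terms: List[Tuple[str, float]],
) -> float:
    """Reconstruction loss plus every selected term scaled by its own weight."""
    loss = recon
    for name, weight in terms:
        loss += weight * LOSS_TERMS[name](mu, logvar)
    return loss


print(total_loss(10.0, mu=[1.0, 0.0], logvar=[0.0, 0.0], terms=[("kld", 4.0)]))  # prints 12.0
```

With a single `("kld", 4.0)` entry this reduces to a β-VAE-style objective (reconstruction plus β times the KL term). Further registry entries would each add their own weighted penalty to the same sum.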

Requirements and Installation

Install the requirements with pip install -r requirements.txt, or build a conda environment with conda env create -f environment.yml.

The library visualizes the reconstructed images and the traversed latent spaces, and saves them both as static frames and as animated GIFs. It also makes extensive use of the web-based Weights & Biases toolkit for logging and visualization.
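Latent traversal, as generally practiced, decodes copies of a latent code in which one dimension is swept over a range while the others stay fixed, so each decoded row shows what that dimension controls. The sketch below only builds the latent codes. The function names and the default range [-2, 2] are assumptions, and decoding the codes and writing frames or GIFs is left to the model and an imaging library.

```python
from typing import List


def traverse_latent(
    z: List[float], dim: int, lo: float = -2.0, hi: float = 2.0, steps: int = 5
) -> List[List[float]]:
    """Copies of z in which only z[dim] sweeps evenly from lo to hi."""
    if steps < 2:
        raise ValueError("steps must be at least 2")
    step = (hi - lo) / (steps - 1)
    frames = []
    for i in range(steps):
        z_new = list(z)  # keep every other dimension fixed
        z_new[dim] = lo + i * step
        frames.append(z_new)
    return frames


def traversal_grid(
    z: List[float], lo: float = -2.0, hi: float = 2.0, steps: int = 5
) -> List[List[List[float]]]:
    """One row of traversed codes per latent dimension."""
    return [traverse_latent(z, d, lo, hi, steps) for d in range(len(z))]


grid = traversal_grid([0.5, 0.1, -0.3])
print([code[1] for code in grid[1]])  # [-2.0, -1.0, 0.0, 1.0, 2.0]
```

Each row of `grid` would be decoded into one strip of images; stacking the strips gives a static frame, and stepping through the columns gives the frames of an animation.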

Training

python main.py [[--ARG ARG_VALUE] ...]

or

bash scripts/SCRIPT_NAME

Flags and Configs

For the complete list of arguments, please check the source.

Data Setup

To run the scripts:

1- Set the --dset_dir flag or the $DISENTANGLEMENT_LIB_DATA environment variable to the directory holding all the datasets (the flag takes priority).

2- Set the --dset_name flag or the $DATASET_NAME environment variable to the name of the dataset (the flag takes priority). The supported datasets are celebA, dsprites (and DeepMind's color, noisy, and scream variants), smallnorb, cars3d, mpi3d_toy, mpi3d_realistic, and mpi3d_real.
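The "flag takes priority over the environment variable" rule in the steps above can be sketched with the standard `argparse` and `os` modules. This is not the library's actual argument parser: the helper names are hypothetical and the option spellings `--dset_dir` and `--dset_name` are assumed.

```python
import argparse
import os
from typing import Optional, Sequence


def resolve_setting(flag_value: Optional[str], env_var: str) -> Optional[str]:
    """Prefer the command-line value; otherwise fall back to the environment variable."""
    if flag_value is not None:
        return flag_value
    return os.environ.get(env_var)


def parse_data_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    # Assumed option spellings; check the library's argument parser for the real ones.
    parser.add_argument("--dset_dir", default=None)
    parser.add_argument("--dset_name", default=None)
    args = parser.parse_args(argv)
    args.dset_dir = resolve_setting(args.dset_dir, "DISENTANGLEMENT_LIB_DATA")
    args.dset_name = resolve_setting(args.dset_name, "DATASET_NAME")
    return args
```

For example, with DATASET_NAME=dsprites exported, parse_data_args(["--dset_dir", "/data"]) takes dset_dir from the flag and dset_name from the environment.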

Please check the mpi3d dataset repository for its license agreement, and consider citing the authors' work.

Currently, there are two dataloaders in place:

NeurIPS 2019 Disentanglement Challenge

To use this code in the NeurIPS 2019 Disentanglement Challenge

Sample Results

Method (latent traversal visualizations omitted)
VAE
FactorVAE
CVAE (conditioned on shape): the right-most item traverses the condition
IFCVAE (factorized on shape): the right-most factor is enforced to encode the shape
BetaTCVAE
VAE
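For the conditional rows above, one common CVAE design (assumed here, not taken from this library) appends a one-hot condition such as the shape label to the latent code before decoding; dSprites, for instance, has three shapes. Traversing the condition then means holding the latent code fixed and cycling through the labels, as in this hypothetical sketch.

```python
from typing import List, Sequence


def one_hot(index: int, num_classes: int) -> List[float]:
    """One-hot encoding of a class label."""
    if not 0 <= index < num_classes:
        raise ValueError("label out of range")
    vec = [0.0] * num_classes
    vec[index] = 1.0
    return vec


def conditioned_input(z: Sequence[float], label: int, num_classes: int) -> List[float]:
    """Decoder input: the latent code followed by the one-hot condition."""
    return list(z) + one_hot(label, num_classes)


def traverse_condition(z: Sequence[float], num_classes: int) -> List[List[float]]:
    """Same latent code paired with every possible condition value."""
    return [conditioned_input(z, c, num_classes) for c in range(num_classes)]


for row in traverse_condition([0.2, -0.4], num_classes=3):
    print(row)
```

Decoding each row would render the same latent factors under each condition value.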

Contributions

Contributions of any kind are welcome, especially implementations of additional disentanglement algorithms. Feel free to report bugs, request features, or ask questions through the issue tracker, or contact me directly by email at amirabdi@ece.ubc.ca.