Efficient-VDVAE

The official Pytorch and JAX implementation of "Efficient-VDVAE: Less is more" (arXiv preprint: https://arxiv.org/abs/2203.13751)

Louay Hazami · Rayhane Mama · Ragavan Thurairatnam





Efficient-VDVAE is a memory and compute efficient very deep hierarchical VAE. It converges faster and is more stable than current hierarchical VAE models. It also achieves SOTA likelihood-based performance on several image datasets.

Pre-trained model checkpoints

We provide checkpoints of pre-trained models on MNIST, CIFAR-10, Imagenet 32x32, Imagenet 64x64, CelebA 64x64, CelebAHQ 256x256 (5-bits and 8-bits), FFHQ 256x256 (5-bits and 8-bits), CelebAHQ 1024x1024 and FFHQ 1024x1024 via the links in the table below. All provided models are the ones trained for Table 4 of the paper.

| Dataset | Pytorch Logs | Pytorch Checkpoints | JAX Logs | JAX Checkpoints | Negative ELBO |
|---|---|---|---|---|---|
| MNIST | link | link | link | link | 79.09 nats |
| CIFAR-10 | Queued | Queued | link | link | 2.87 bits/dim |
| Imagenet 32x32 | link | link | link | link | 3.58 bits/dim |
| Imagenet 64x64 | link | link | link | link | 3.30 bits/dim |
| CelebA 64x64 | link | link | link | link | 1.83 bits/dim |
| CelebAHQ 256x256 (5-bits) | link | link | link | link | 0.51 bits/dim |
| CelebAHQ 256x256 (8-bits) | link | link | link | link | 1.35 bits/dim |
| FFHQ 256x256 (5-bits) | link | link | link | link | 0.53 bits/dim |
| FFHQ 256x256 (8-bits) | link | link | link | link | 2.17 bits/dim |
| CelebAHQ 1024x1024 | link | link | link | link | 1.01 bits/dim |
| FFHQ 1024x1024 | link | link | link | link | 2.30 bits/dim |

Notes:

Pre-requisites

To run this codebase, you need:

We recommend running all the commands below inside a Linux screen session or any other terminal multiplexer, since some of them can take hours or days to finish and you don't want them to be killed when you close your terminal.
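
For example, with GNU screen (the session name below is just an illustration):

screen -S efficient_vdvae   # Start a named session
sh train.sh                 # Launch the long-running command inside it
# Detach with Ctrl-A then D; re-attach later with:
screen -r efficient_vdvae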

Note:

Installation

To create the docker image used in both the Pytorch and JAX implementations:

cd build  
docker build -t efficient_vdvae_image .  

Note:

All code executions should be done within a docker container. To start the docker container, we provide a utility script:

sh docker_run.sh  # Starts the container and attaches terminal
cd /workspace/Efficient-VDVAE  # Inside docker container
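
For reference, a script like docker_run.sh typically wraps a single docker run call. A minimal sketch is shown below, assuming the image name built above and GPU pass-through via the NVIDIA container toolkit; the actual script in the repository may use different flags and mounts:

# Minimal sketch, not the exact contents of docker_run.sh
docker run -it --rm \
    --gpus all \
    -v "$(pwd)":/workspace/Efficient-VDVAE \
    efficient_vdvae_image bash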

Setup datasets

All datasets can be automatically downloaded and pre-processed using the convenience script we provide:

cd data_scripts
sh download_and_preprocess.sh <dataset_name>
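
For instance, to download and prepare CIFAR-10 (the accepted <dataset_name> values are assumed to mirror the dataset_source options shown later in this README, e.g. cifar-10):

cd data_scripts
sh download_and_preprocess.sh cifar-10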

Notes:

Setting the hyper-parameters

In this repository, we use the hparams library (already included in the Dockerfile) for hyper-parameter management:

We highly recommend taking a closer look at how this library works by reading the hparams library documentation, the parameters description and Figures 4 and 5 in the paper before trying to run Efficient-VDVAE.

We have heavily tested the robustness and stability of our approach, so changing the model/optimization hyper-parameters to reduce memory load should not introduce drastic instabilities that make the model untrainable, as long as the changes don't negate the important stability points we describe in the paper.

Training the Efficient-VDVAE

To run Efficient-VDVAE in Torch:

cd efficient_vdvae_torch  
# Set the hyper-parameters in "hparams.cfg" file  
# Set "NUM_GPUS_PER_NODE" in "train.sh" file  
sh train.sh  

To run Efficient-VDVAE in JAX:

cd efficient_vdvae_jax  
# Set the hyper-parameters in "hparams.cfg" file  
python train.py  

If you want to run the model with fewer GPUs than are available on the hardware, for example 2 GPUs out of 8:

CUDA_VISIBLE_DEVICES=0,1 sh train.sh  # For torch  
CUDA_VISIBLE_DEVICES=0,1 python train.py  # For JAX  

Models automatically create checkpoints during training. To resume a model from its last checkpoint, set its <run.name> in the hparams.cfg file and re-run the same training commands.

Training commands save the hparams of the defined run in the .cfg file. If you want to restart a pre-existing run from scratch (by re-using its name in hparams.cfg), we provide a convenience script for resetting saved runs:

cd efficient_vdvae_torch  # or cd efficient_vdvae_jax  
sh reset.sh <run.name>  # <run.name> is the first field in hparams.cfg  
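
For example, for a hypothetical run named my_run:

cd efficient_vdvae_torch
sh reset.sh my_run   # Resets the saved state of the run named "my_run"
# Adjust hparams.cfg if needed, then start training again from scratch
sh train.sh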

Note:

Monitoring the training process

While writing this codebase, we put extra emphasis on verbosity and logging. Aside from the logs printed to the terminal during training, you can monitor the training progress and keep track of useful metrics using Tensorboard:

# While outside efficient_vdvae_torch or efficient_vdvae_jax  
# Run outside the docker container
tensorboard --logdir . --port <port_id> --reload_multifile True  

In the browser, navigate to localhost:<port_id> to visualize all saved metrics.
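
For example, with port 6006 (any free port works):

tensorboard --logdir . --port 6006 --reload_multifile True
# Then open http://localhost:6006 in the browser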

If Tensorboard is not installed on the host (outside the docker container), install it with:

pip install --upgrade tensorboard

Inference with the Efficient-VDVAE

Efficient-VDVAE supports multiple inference modes:

To run the inference:

cd efficient_vdvae_torch  # or cd efficient_vdvae_jax  
# Set the inference mode in "logs-<run.name>/hparams-<run.name>.cfg"  
# Set the same <run.name> in "hparams.cfg"  
python synthesize.py  

Notes:

Using custom datasets

If you want to train the networks on your custom datasets, you need the following prerequisites:

To use your custom dataset (in both training and inference), you only need to modify the data section of your hparams.cfg file. Specifically, set dataset_source = 'custom', then change the data paths and image metadata.

For an example custom grayscale dataset at 512x512 resolution, the data section of hparams.cfg would look like:

[data]
# Data section: Defines the dataset parameters
# To change a dataset to run the code on:
#   - Change the data.dataset_source to reflect which dataset you're trying to run.
#           This controls which data loading scripts to use and how to normalize
#   - Change the paths. For all datasets but binarized_mnist and cifar-10, define where the data lives on disk.
#   - Change the metadata: Define the image resolution, the number of channels and the color bit-depth of the data.

# Dataset source. Can be one of ('binarized_mnist', 'cifar-10', 'imagenet', 'celebA', 'celebAHQ', 'ffhq', 'custom')
dataset_source = 'custom'

# Data paths. Not used for (binarized_mnist, cifar-10)
train_data_path = '../datasets/my_custom_data/train_data/'
val_data_path = '../datasets/my_custom_data/val_data/'
synthesis_data_path = '../datasets/my_custom_data/synthesis_data/'

# Image metadata
# Image resolution of the dataset (Height and Width, assumed square)
target_res = 512
# Image channels of the dataset (Number of color channels)
channels = 1
# Image color depth in the dataset (bit-depth of each color channel)
num_bits = 8.
# Whether to do a random horizontal flip of images when loading the data (not applicable to MNIST)
random_horizontal_flip = True
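
Under these settings, the data loaders are assumed to read image files directly from the configured folders. A sketch of the corresponding directory layout (folder names taken from the paths above, file formats illustrative):

mkdir -p ../datasets/my_custom_data/{train_data,val_data,synthesis_data}
# train_data/       512x512 grayscale training images (e.g. PNG files)
# val_data/         validation images
# synthesis_data/   images used by the inference modes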

You should also change the model section of hparams.cfg to create a model that works well with your data resolution. When in doubt, get inspired by the example hparams in the egs folder.

Notes:

Potential TODOs

Bibtex

If you happen to use this codebase, please cite our paper:

@article{hazami2022efficient,
  title={Efficient-VDVAE: Less is more},
  author={Hazami, Louay and Mama, Rayhane and Thurairatnam, Ragavan},
  journal={arXiv preprint arXiv:2203.13751},
  year={2022}
}