rosinality / vq-vae-2-pytorch

Implementation of Generating Diverse High-Fidelity Images with VQ-VAE-2 in PyTorch

vq-vae-2-pytorch

Update

train_vqvae.py and vqvae.py now support distributed training. Pass --n_gpu [NUM_GPUS] to train_vqvae.py to train on [NUM_GPUS] GPUs.
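Multi-GPU training of this kind is normally data-parallel. Each GPU runs its own copy of the model on a different slice of the dataset, and the gradients are averaged across GPUs before every optimizer step so that all copies stay identical. The minimal pure-Python sketch below illustrates those two ideas. shard_indices and average_gradients are hypothetical helpers, not functions from this repository. In PyTorch the same roles are played by torch.utils.data.distributed.DistributedSampler and torch.nn.parallel.DistributedDataParallel, and it is an assumption that train_vqvae.py follows exactly this scheme.

```python
from typing import List, Sequence


def shard_indices(n_samples: int, rank: int, world_size: int) -> List[int]:
    """Indices of the samples that process `rank` (one GPU) trains on.

    Hypothetical helper: every `world_size`-th index starting at `rank`,
    so the shards are disjoint and together cover the whole dataset.
    """
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    return list(range(rank, n_samples, world_size))


def average_gradients(per_gpu_grads: Sequence[Sequence[float]]) -> List[float]:
    """Element-wise mean of the gradients computed on each GPU (an all-reduce average)."""
    n = len(per_gpu_grads)
    return [sum(values) / n for values in zip(*per_gpu_grads)]


# Example: 10 images split across 4 GPUs (as with --n_gpu 4).
shards = [shard_indices(10, rank, 4) for rank in range(4)]
print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]

# Each GPU computed a gradient for the same 2 parameters on its own shard.
print(average_gradients([[1.0, 2.0], [3.0, 2.0], [2.0, 5.0], [2.0, 3.0]]))  # [2.0, 3.0]
```

Because every GPU applies the same averaged gradient, the model copies never drift apart, which is what makes splitting the data safe.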

Requisites

Checkpoint of VQ-VAE pretrained on FFHQ

Usage

Currently supports 256px images (with a top/bottom hierarchical prior).
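For a 256px input, the two levels of the hierarchy work on code maps of different resolutions. The sketch below assumes that the bottom encoder downsamples by 4 and the top encoder downsamples by a further 2. This gives a 64x64 bottom map and a 32x32 top map, which are the grid sizes the VQ-VAE-2 paper reports for 256px images. latent_grid_sizes is a hypothetical helper, and the strides are assumptions that should be checked against vqvae.py.

```python
from typing import Tuple

Shape = Tuple[int, int]


def latent_grid_sizes(image_size: int, bottom_stride: int = 4, top_stride: int = 2) -> Tuple[Shape, Shape]:
    """Return (top_shape, bottom_shape) of the code maps for a square image.

    bottom_stride and top_stride are assumed downsampling factors of the
    bottom and top encoders; they are not read from vqvae.py.
    """
    if image_size % (bottom_stride * top_stride):
        raise ValueError("image_size must be divisible by bottom_stride * top_stride")
    bottom = image_size // bottom_stride  # bottom codes: finer, local detail
    top = bottom // top_stride            # top codes: coarser, global structure
    return (top, top), (bottom, bottom)


top_shape, bottom_shape = latent_grid_sizes(256)
print(top_shape, bottom_shape)  # (32, 32) (64, 64)

# Discrete codes per image, versus 256 * 256 * 3 = 196608 pixel values.
print(top_shape[0] * top_shape[1] + bottom_shape[0] * bottom_shape[1])  # 5120
```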

  1. Stage 1 (VQ-VAE)

python train_vqvae.py [DATASET PATH]

If you use FFHQ, I highly recommend preprocessing the images (resize them and convert them to JPEG).
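Stage 1 compresses each image into discrete codes by vector quantization. Each encoder output vector is replaced by its nearest entry in a learned codebook, and the index of that entry is the code that stage 2 later models. The pure-Python sketch below shows only that nearest-neighbour lookup, using squared Euclidean distance as in the VQ-VAE papers. The codebook and encoder outputs are made-up toy values. The real model performs this lookup on tensors and also updates the codebook during training, which is not shown.

```python
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def squared_distance(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize(vectors: Sequence[Vector], codebook: Sequence[Vector]) -> Tuple[List[int], List[List[float]]]:
    """Map each vector to the index of its nearest codebook entry.

    Returns the code indices and the quantized vectors (the chosen entries).
    """
    indices: List[int] = []
    quantized: List[List[float]] = []
    for v in vectors:
        best = min(range(len(codebook)), key=lambda k: squared_distance(v, codebook[k]))
        indices.append(best)
        quantized.append(list(codebook[best]))
    return indices, quantized


codebook = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]        # toy 3-entry codebook
encoder_out = [[0.9, 0.1], [0.2, 0.7], [0.1, -0.1]]    # toy encoder outputs
idx, q = quantize(encoder_out, codebook)
print(idx)  # [1, 2, 0]
```

The decoder receives the quantized vectors, and only the small integer indices need to be stored for the next stage.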

  2. Extract codes for stage 2 training

python extract_code.py --ckpt checkpoint/[VQ-VAE CHECKPOINT] --name [LMDB NAME] [DATASET PATH]
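This step runs the trained VQ-VAE over the dataset and stores the resulting top and bottom codes in the LMDB database named by --name. Stage 2 is then trained on that database alone. The sketch below imitates the idea with a plain dict standing in for LMDB, because lmdb is a third-party package. The record layout is an assumption for illustration: a hypothetical CodeRow holding the top codes, bottom codes and filename, stored under the string key of its index, plus a "length" entry. Records are serialized with JSON here for simplicity, and this is not necessarily what extract_code.py writes.

```python
import json
from collections import namedtuple

# Hypothetical record layout: one entry per image.
CodeRow = namedtuple("CodeRow", ["top", "bottom", "filename"])


def write_codes(store, rows):
    """Serialize CodeRow records into a key-value store (a dict stands in for LMDB)."""
    for index, row in enumerate(rows):
        store[str(index).encode("utf-8")] = json.dumps(row._asdict()).encode("utf-8")
    store[b"length"] = str(len(rows)).encode("utf-8")


def read_code(store, index):
    """Load one record back, as a stage 2 dataset's __getitem__ might."""
    return CodeRow(**json.loads(store[str(index).encode("utf-8")]))


store = {}
rows = [
    # Toy code maps: a 2x2 top map and a 4x4 bottom map per image.
    CodeRow(top=[[3, 1], [0, 2]], bottom=[[5, 5, 1, 0]] * 4, filename="00000.jpg"),
    CodeRow(top=[[7, 7], [7, 7]], bottom=[[2, 2, 2, 2]] * 4, filename="00001.jpg"),
]
write_codes(store, rows)
print(int(store[b"length"]))         # 2
print(read_code(store, 1).filename)  # 00001.jpg
```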

  3. Stage 2 (PixelSNAIL)

python train_pixelsnail.py [LMDB NAME]

A larger PixelSNAIL model may give better results; the current model size is reduced because of GPU constraints.
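With a top/bottom hierarchical prior, sampling runs in sequence. The top prior first generates the top code map one position at a time in raster order. The bottom prior then generates the bottom map conditioned on the finished top map, and the VQ-VAE decoder turns both maps into an image. The sketch below shows only that control flow. sample_codes, sample_hierarchical, toy_top and toy_bottom are hypothetical names, and the toy functions stand in for trained PixelSNAIL models. This is not the code in train_pixelsnail.py.

```python
from typing import Callable, List, Optional, Sequence

Grid = List[List[int]]
# predict(history, row, col, condition) -> next code index
Predictor = Callable[[Sequence[int], int, int, Optional[Grid]], int]


def sample_codes(height: int, width: int, predict: Predictor, condition: Optional[Grid] = None) -> Grid:
    """Fill a height x width code map one position at a time in raster order.

    `predict` only sees codes that were already sampled (rows above, and
    columns to the left in the current row), which is the autoregressive
    ordering a PixelSNAIL-style prior is trained with.
    """
    history: List[int] = []
    grid: Grid = []
    for r in range(height):
        row = []
        for c in range(width):
            code = predict(tuple(history), r, c, condition)
            row.append(code)
            history.append(code)
        grid.append(row)
    return grid


def sample_hierarchical(predict_top: Predictor, predict_bottom: Predictor,
                        top_size=(32, 32), bottom_size=(64, 64)):
    """Sample the top map first, then the bottom map conditioned on it."""
    top = sample_codes(*top_size, predict_top)
    bottom = sample_codes(*bottom_size, predict_bottom, condition=top)
    return top, bottom


# Toy stand-ins for trained priors (hypothetical, for illustration only).
def toy_top(history, r, c, condition):
    return len(history) % 4


def toy_bottom(history, r, c, condition):
    return condition[r // 2][c // 2]  # copy the top code covering this position


top, bottom = sample_hierarchical(toy_top, toy_bottom, top_size=(2, 2), bottom_size=(4, 4))
print(top)        # [[0, 1], [2, 3]]
print(bottom[2])  # [2, 2, 3, 3]
```

In this naive loop, every position costs one sequential model call, so a 64x64 bottom map needs 4096 calls. Model size therefore affects sampling time as well as GPU memory.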

Sample

Stage 1

Note: this sample was produced during training.

[Image: sample from Stage 1 (VQ-VAE)]