kamenbliznashki / pixel_models

PyTorch implementations of autoregressive pixel models: PixelCNN, PixelCNN++, PixelSNAIL

Pixel models

Implementations of the autoregressive algorithms from the PixelCNN, PixelCNN++, and PixelSNAIL papers.

Usage

The architecture files pixelcnn.py, pixelcnnpp.py, and pixelsnail.py contain the model classes, loss functions, and generation functions; optim.py implements an exponential moving average (EMA) wrapper around torch optimizers; main.py contains the common logic for training, evaluation, and generation.
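The wrapper in optim.py is not reproduced here. As a rough illustration of the idea, the following is a minimal sketch assuming a hypothetical `EMAWrapper` class: it forwards `step()` to the wrapped optimizer and then updates a shadow copy of each parameter as shadow = decay * shadow + (1 - decay) * param. The class name, the `decay` default, and the `swap_ema()` helper are illustrative assumptions, not the repository's API.

```python
import torch


class EMAWrapper:
    """Hypothetical sketch: wraps a torch optimizer and keeps an exponential
    moving average (EMA) of every parameter the optimizer updates."""

    def __init__(self, optimizer, decay=0.9995):
        self.optimizer = optimizer
        self.decay = decay
        # Shadow copy of each parameter, keyed by the parameter object itself.
        self.shadow = {
            p: p.detach().clone()
            for group in optimizer.param_groups
            for p in group["params"]
        }

    def zero_grad(self):
        self.optimizer.zero_grad()

    @torch.no_grad()
    def step(self):
        # Regular optimizer update first, then fold the new values into the EMA.
        self.optimizer.step()
        for p, s in self.shadow.items():
            s.mul_(self.decay).add_(p, alpha=1 - self.decay)

    @torch.no_grad()
    def swap_ema(self):
        """Exchange live and averaged weights; call again to swap back."""
        for p, s in self.shadow.items():
            tmp = p.detach().clone()
            p.copy_(s)
            s.copy_(tmp)


# Usage (illustrative):
# model = torch.nn.Linear(2, 1)
# opt = EMAWrapper(torch.optim.Adam(model.parameters(), lr=1e-3), decay=0.999)
```

Calling `swap_ema()` before evaluation or sampling and again afterwards restores the live training weights. Decay values closer to 1 average over a longer history of updates.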

To train a model:

python main.py --train
               --dataset      # choice from cifar10, mnist, colored-mnist
               --data_path    # path to dataset
               --[add'l options]
               --model        # choice from pixelcnn, pixelcnnpp, pixelsnail;
                              # activates subparser for specific model params

Additional options are defined in the argument parser in main.py.
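For intuition on what training optimizes, here is a minimal sketch of a PixelCNN-style masked convolution and one maximum-likelihood training step. It is not the code in pixelcnn.py: `MaskedConv2d`, `TinyPixelCNN`, and `train_step` are hypothetical names, and the sketch assumes single-channel images with 256 intensity levels modeled by a per-pixel softmax, as in the original PixelCNN. PixelCNN++ instead uses a discretized logistic mixture likelihood.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedConv2d(nn.Conv2d):
    """Hypothetical sketch: a convolution that only sees pixels above and to the left.

    Mask 'A' (first layer) also hides the center pixel; mask 'B' (later layers)
    lets a position see its own features from the previous layer.
    """

    def __init__(self, mask_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert mask_type in ("A", "B")
        kh, kw = self.kernel_size
        mask = torch.ones_like(self.weight)
        # Center row: zero the center and everything right of it (A), or only the right (B).
        mask[:, :, kh // 2, kw // 2 + (mask_type == "B"):] = 0
        # Every row below the center row.
        mask[:, :, kh // 2 + 1:, :] = 0
        self.register_buffer("mask", mask)

    def forward(self, x):
        return F.conv2d(x, self.weight * self.mask, self.bias,
                        self.stride, self.padding, self.dilation, self.groups)


class TinyPixelCNN(nn.Module):
    """Hypothetical toy model: outputs per-pixel logits over 256 intensity levels."""

    def __init__(self, n_levels=256, hidden=32):
        super().__init__()
        self.net = nn.Sequential(
            MaskedConv2d("A", 1, hidden, 7, padding=3), nn.ReLU(),
            MaskedConv2d("B", hidden, hidden, 3, padding=1), nn.ReLU(),
            nn.Conv2d(hidden, n_levels, 1),  # 1x1 conv adds no spatial context
        )

    def forward(self, x):
        return self.net(x)


def train_step(model, optimizer, x_int):
    """One NLL step on integer pixels x_int of shape (B, 1, H, W) in [0, 255]."""
    x = x_int.float() / 255.0 * 2 - 1                  # scale inputs to [-1, 1]
    logits = model(x)                                   # (B, 256, H, W)
    nll = F.cross_entropy(logits, x_int.squeeze(1))     # mean NLL per pixel, in nats
    optimizer.zero_grad()
    nll.backward()
    optimizer.step()
    return nll.item() / math.log(2)                     # convert to bits per dimension
```

Because each output position only depends on pixels earlier in raster order, a single forward pass scores every pixel of the image at once, which keeps training parallel even though the model is autoregressive.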

To evaluate a trained model or generate samples from it:

python main.py --generate     # [evaluate]; for evaluation, also specify --dataset and --data_path
               --restore_file # path to .pt checkpoint
               --model        # choice from pixelcnn, pixelcnnpp, pixelsnail

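Generation is sequential: each pixel is sampled conditioned on the pixels already drawn, so producing an H x W image takes H * W full forward passes. The following is a minimal sketch of such a sampling loop, not the generation functions in this repository; `sample` is a hypothetical name, and it assumes a single-channel model that maps a (B, 1, H, W) input scaled to [-1, 1] to logits of shape (B, 256, H, W).

```python
import torch


@torch.no_grad()
def sample(model, shape, n_levels=256):
    """Hypothetical sketch: draw images pixel by pixel in raster order."""
    b, c, h, w = shape
    if c != 1:
        raise ValueError("this sketch only handles single-channel images")
    x = torch.zeros(shape, dtype=torch.long)
    for i in range(h):
        for j in range(w):
            # One full forward pass per pixel; only position (i, j) is used.
            logits = model(x.float() / (n_levels - 1) * 2 - 1)
            probs = torch.softmax(logits[:, :, i, j], dim=1)       # (B, n_levels)
            x[:, 0, i, j] = torch.multinomial(probs, 1).squeeze(1)
    return x
```

The H * W sequential forward passes are a large part of why sampling from these models is slow, in contrast to training, where all positions are scored in parallel.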
Results

Autoregressive models are particularly computationally intensive. I tested the models above on a single batch of CIFAR10 and MNIST, but I have not tried to replicate the published results, since I only needed these models as building blocks for other models.

Datasets

For colored MNIST, see Berkeley's CS294-158, through which the dataset can be downloaded.

Useful resources

TensorFlow implementations by the authors of PixelCNN++ and PixelSNAIL

Dependencies