chenlin9 / Autoregressive-Modeling-with-Distribution-Smoothing


Improved Autoregressive Modeling with Distribution Smoothing

This repo contains the implementation for the paper Improved Autoregressive Modeling with Distribution Smoothing, by Chenlin Meng, Jiaming Song, Yang Song, Shengjia Zhao and Stefano Ermon (Stanford AI Lab).

Running Experiments

Dependencies

Run the following command to install the Python packages required by the code:

pip install -r requirements.txt

Stage 1: Learning the smoothed distribution

In stage 1, a PixelCNN++ model is trained on the smoothed data distribution, i.e. the data with added noise. To train it on CIFAR-10, run:

python main.py --runner SmoothedPixelCNNPPTrainRunner --config pixelcnnpp_smoothed_train_cifar10.yml --doc cifar10_smoothed_0.3 --ni
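For intuition about what "smoothed" means here, the following is a minimal Python sketch of Gaussian smoothing, x_tilde = x + sigma * eps with eps ~ N(0, I). It is not the repo's code: gaussian_smooth is a hypothetical helper, the toy image values are assumed to lie in [-1, 1], and reading the 0.3 in the --doc names above as sigma is only a guess.

```python
import random
from typing import List


def gaussian_smooth(pixels: List[float], sigma: float, rng: random.Random) -> List[float]:
    """Draw x_tilde ~ N(x, sigma^2 I) by adding i.i.d. Gaussian noise to every pixel."""
    return [p + rng.gauss(0.0, sigma) for p in pixels]


rng = random.Random(0)
clean = [0.0, 0.5, -0.5, 1.0]  # toy flattened "image"; value range is an assumption
noisy = gaussian_smooth(clean, sigma=0.3, rng=rng)
print(noisy)
```

The stage 1 model is then fit to samples like noisy rather than to clean, which is what makes its target distribution smoother.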

Stage 2: Reverse smoothing

To reverse the smoothing process, we train a second PixelCNN++ model that learns the original data distribution conditioned on smoothed samples. To train it on CIFAR-10, run:

python main.py --runner SmoothedPixelCNNPPTrainRunner --config pixelcnnpp_conditioned_train_cifar10.yml --doc reverse_cifar10_0.3 --ni
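The reverse model's training data can be pictured as (condition, target) pairs: a smoothed copy x_tilde of each clean image x, together with x itself, so that a conditional model fit to these pairs learns p(x | x_tilde). The Python sketch below is illustrative only. It assumes Gaussian smoothing and flattened toy images, and make_reverse_pairs is a hypothetical helper rather than the repo's data pipeline.

```python
import random
from typing import List, Tuple

Image = List[float]  # flattened toy image (assumption)


def make_reverse_pairs(batch: List[Image], sigma: float, seed: int = 0) -> List[Tuple[Image, Image]]:
    """Return (x_tilde, x) pairs; a conditional model trained on them learns p(x | x_tilde)."""
    rng = random.Random(seed)
    # Each condition is a fresh Gaussian-smoothed copy of the clean target.
    return [([p + rng.gauss(0.0, sigma) for p in x], x) for x in batch]


batch = [[0.0, 0.5], [-1.0, 1.0]]
pairs = make_reverse_pairs(batch, sigma=0.3)
for cond, target in pairs:
    print("condition:", [round(v, 3) for v in cond], "target:", target)
```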

Sampling

Sampling from stage 1:

First, modify the following field in pixelcnnpp_smoothed_sample.yml:

ckpt_path: path to the model trained on the smoothed data in stage 1.

The dataset parameter may also need to be changed accordingly; the options are MNIST, CIFAR10, or celeba. Then run:

python main.py --runner PixelCNNPPSamplerRunner --config pixelcnnpp_smoothed_sample.yml --doc cifar10_0.3_images
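The config edits described for both sampling stages can also be applied programmatically. The following is a minimal Python sketch. It assumes the YAML file has already been loaded into a flat dict (for example with a YAML loader) and that ckpt_path, dataset and noisy_samples_path are top-level keys; the real files may nest them differently. update_sample_config is a hypothetical helper, and the dataset strings are copied from this README on the assumption that the code accepts exactly these spellings.

```python
from typing import Any, Dict, Optional

# Dataset options as listed in this README (exact accepted strings are an assumption).
DATASET_CHOICES = {"MNIST", "CIFAR10", "celeba"}


def update_sample_config(
    cfg: Dict[str, Any],
    ckpt_path: str,
    dataset: str,
    noisy_samples_path: Optional[str] = None,
) -> Dict[str, Any]:
    """Return a copy of a flat sampling config with the fields this README says to edit."""
    if dataset not in DATASET_CHOICES:
        raise ValueError(f"dataset must be one of {sorted(DATASET_CHOICES)}, got {dataset!r}")
    updated = dict(cfg)  # leave the caller's dict untouched
    updated["ckpt_path"] = ckpt_path
    updated["dataset"] = dataset
    # Only the stage 2 (reverse) sampling config needs the stage 1 samples.
    if noisy_samples_path is not None:
        updated["noisy_samples_path"] = noisy_samples_path
    return updated
```

For stage 1, one might call update_sample_config(cfg, "path/to/stage1_checkpoint", "CIFAR10"), where the path is a placeholder; for stage 2, pass the reverse-smoothing checkpoint as ckpt_path and the stage 1 samples as noisy_samples_path.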

Sampling from stage 2:

First, modify the following fields in pixelcnnpp_reverse_sample.yml:

noisy_samples_path: path to the noisy samples generated by the model trained on the smoothed data in stage 1.

ckpt_path: path to the reverse smoothing model in stage 2.

The dataset parameter may also need to be changed accordingly; the options are MNIST, CIFAR10, or celeba. Then run:

python main.py --runner PixelCNNPPSamplerRunner --config pixelcnnpp_reverse_sample.yml --doc cifar10_denoise_images
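Sampling is ancestral: stage 1 produces smoothed samples x_tilde, and the stage 2 model draws x conditioned on each x_tilde. The toy 1-D Python sketch below shows why this recovers the data distribution. It assumes Gaussian data x ~ N(MU, S^2) and Gaussian smoothing with std SIGMA, so both stages have closed forms; sample_smoothed and sample_reverse are hypothetical stand-ins for the two trained PixelCNN++ models, not part of this repo.

```python
import math
import random
from typing import List

# Toy setup (assumption): data x ~ N(MU, S^2), smoothing noise std SIGMA.
MU, S, SIGMA = 0.2, 0.5, 0.3


def sample_smoothed(rng: random.Random) -> float:
    """Stage 1 stand-in: draw x_tilde from the smoothed marginal N(MU, S^2 + SIGMA^2)."""
    return rng.gauss(MU, math.sqrt(S**2 + SIGMA**2))


def sample_reverse(x_tilde: float, rng: random.Random) -> float:
    """Stage 2 stand-in: draw x from the exact Gaussian posterior p(x | x_tilde)."""
    gain = S**2 / (S**2 + SIGMA**2)
    mean = MU + gain * (x_tilde - MU)
    std = math.sqrt(S**2 * SIGMA**2 / (S**2 + SIGMA**2))
    return rng.gauss(mean, std)


def two_stage_sample(n: int, seed: int = 0) -> List[float]:
    """Ancestral sampling: x_tilde from stage 1, then x from stage 2 given x_tilde."""
    rng = random.Random(seed)
    return [sample_reverse(sample_smoothed(rng), rng) for _ in range(n)]


samples = two_stage_sample(200_000)
mean = sum(samples) / len(samples)
var = sum((x - mean) ** 2 for x in samples) / len(samples)
print(f"mean={mean:.3f} (target {MU}), var={var:.3f} (target {S**2})")
```

The empirical mean and variance should land close to MU and S^2: the posterior's shrinkage toward MU exactly offsets the extra variance SIGMA^2 introduced by smoothing.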

References and Acknowledgements

@article{meng2021improved,
  title={Improved Autoregressive Modeling with Distribution Smoothing},
  author={Meng, Chenlin and Song, Jiaming and Song, Yang and Zhao, Shengjia and Ermon, Stefano},
  journal={arXiv preprint arXiv:2103.15089},
  year={2021}
}

This implementation is based on / inspired by: