The official repository for <Autoencoding Under Normalization Constraints> (Yoon, Noh and Park, ICML 2021) and normalized autoencoders.
The paper proposes Normalized Autoencoder (NAE), which is a novel energy-based model where the energy function is the reconstruction error. NAE effectively remedies outlier reconstruction, a pathological phenomenon limiting the performance of an autoencoder as an outlier detector.
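As a rough illustration, the energy of an input is its reconstruction error E(x), and the model density is a Boltzmann distribution p(x) = exp(-E(x)/T) / Ω, where Ω is a normalizing constant. The sketch below uses a squared L2 error, a temperature T, and a hypothetical linear encoder/decoder written in NumPy. It is a sketch under those assumptions, not the repository's networks or training code.

```python
import numpy as np

def reconstruction_energy(x, encode, decode):
    """Energy = squared L2 reconstruction error, one value per row of x."""
    x_hat = decode(encode(x))
    return np.sum((x - x_hat) ** 2, axis=1)

def unnormalized_density(x, encode, decode, temperature=1.0):
    """exp(-E(x) / T); the normalizing constant Omega is left out (intractable in general)."""
    return np.exp(-reconstruction_energy(x, encode, decode) / temperature)

# Hypothetical linear autoencoder: project 2-D points onto the first axis and back.
W = np.array([[1.0, 0.0]])      # encoder weights, shape (1, 2)
encode = lambda x: x @ W.T      # (n, 2) -> (n, 1)
decode = lambda z: z @ W        # (n, 1) -> (n, 2)

x = np.array([[2.0, 0.0],       # on the "manifold": reconstructed exactly
              [2.0, 3.0]])      # off the manifold: second coordinate is lost
energies = reconstruction_energy(x, encode, decode)
print(energies)  # [0. 9.]
```

Low energy (high density) is assigned to inputs the autoencoder reconstructs well. Keeping Ω finite is the normalization constraint that, per the paper, stops the model from also reconstructing outliers well.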
Paper: https://arxiv.org/abs/2105.05735
5-min video: https://www.youtube.com/watch?v=ra6usGKnPGk
I encourage you to use conda to set up a virtual environment. However, other methods should work without problems.
conda create -n nae python=3.7
The main dependencies of the repository are as follows:
All datasets are stored in the datasets/ directory. Some of them are handled through torchvision.datasets.
When set up, the dataset directory should look as follows.
datasets
├── CelebA
│ ├── Anno
│ ├── Eval
│ └── Img
├── cifar-10-batches-py
├── const_img_gray.npy
├── const_img.npy
├── FashionMNIST
├── ImageNet32
│ ├── train_32x32
│ └── valid_32x32
├── MNIST
├── noise_img.npy
├── omniglot-py
│ ├── images_background
│ └── images_evaluation
├── test_32x32.mat
└── train_32x32.mat
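To check the layout before running tests or training, a small script can compare datasets/ against the tree above. The entries below are copied from that tree; missing_entries is a hypothetical helper, not part of the repository.

```python
from pathlib import Path

# Relative paths copied from the expected datasets/ tree.
EXPECTED_DATASET_ENTRIES = [
    "CelebA/Anno", "CelebA/Eval", "CelebA/Img",
    "cifar-10-batches-py",
    "const_img_gray.npy", "const_img.npy",
    "FashionMNIST",
    "ImageNet32/train_32x32", "ImageNet32/valid_32x32",
    "MNIST",
    "noise_img.npy",
    "omniglot-py/images_background", "omniglot-py/images_evaluation",
    "test_32x32.mat", "train_32x32.mat",
]

def missing_entries(root, expected):
    """Return the entries of `expected` that do not exist under `root`."""
    root = Path(root)
    return [rel for rel in expected if not (root / rel).exists()]

if __name__ == "__main__":
    missing = missing_entries("datasets", EXPECTED_DATASET_ENTRIES)
    if missing:
        print("Missing under datasets/:", *missing, sep="\n  ")
    else:
        print("datasets/ matches the expected layout.")
```

Run it from the repository root; it only reports what is missing and does not download anything.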
Pre-trained models are stored under pretrained/. The pre-trained models are provided through the Dropbox link.
If the pre-trained models are prepared successfully, the directory structure should look like the following.
pretrained
├── celeba64_ood_nae
│ └── z64gr_h32g8
├── cifar_ood_nae
│ └── z32gn
└── mnist_ood_nae
└── z32
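The commands in this README refer to checkpoints named like nae_9.pkl or nae_20.pkl. Assuming a larger number means a later snapshot, the hypothetical helper below finds the most recent one in a run directory. Note that the number has to be compared numerically (nae_9 comes before nae_20), and that the checkpoint named explicitly in a command should take precedence over this guess.

```python
import re
from pathlib import Path

# Assumption: checkpoints follow the pattern nae_<N>.pkl, and larger N is later.
CKPT_PATTERN = re.compile(r"^nae_(\d+)\.pkl$")

def latest_checkpoint(run_dir):
    """Return the name of the nae_<N>.pkl file with the largest N, or None."""
    best_n, best_name = -1, None
    for path in Path(run_dir).iterdir():
        m = CKPT_PATTERN.match(path.name)
        # Compare the integer N, not the string, so nae_20 beats nae_9.
        if m and int(m.group(1)) > best_n:
            best_n, best_name = int(m.group(1)), path.name
    return best_name
```

For example, latest_checkpoint("pretrained/cifar_ood_nae/z32gn") returns whichever nae_<N>.pkl in that directory has the largest N.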
pytest is used for unit testing.
pytest tests
The code should pass all tests once the pre-trained models and datasets are prepared.
Evaluating OOD detection with the pre-trained CIFAR-10 model
python evaluate_ood.py --ood ConstantGray_OOD,FashionMNIST_OOD,SVHN_OOD,CelebA_OOD,Noise_OOD --resultdir pretrained/cifar_ood_nae/z32gn/ --ckpt nae_9.pkl --config z32gn.yml --device 0 --dataset CIFAR10_OOD
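In OOD detection with NAE, each input is scored by its energy (reconstruction error), and inliers should receive lower scores than outliers. A common way to summarize this is AUROC. The sketch below computes AUROC in plain NumPy as the probability that a random outlier scores higher than a random inlier (ties count one half). It is an assumption that evaluate_ood.py reports a metric of this kind; the energy values are made up for illustration.

```python
import numpy as np

def auroc(inlier_scores, outlier_scores):
    """AUROC for detecting outliers, where a higher score means 'more anomalous'.

    Equals P(outlier score > inlier score) + 0.5 * P(tie), i.e. the
    normalized Mann-Whitney U statistic.
    """
    inl = np.asarray(inlier_scores, dtype=float)
    out = np.asarray(outlier_scores, dtype=float)
    # Compare every outlier with every inlier via broadcasting: shape (n_out, n_in).
    greater = (out[:, None] > inl[None, :]).sum()
    ties = (out[:, None] == inl[None, :]).sum()
    return (greater + 0.5 * ties) / (len(inl) * len(out))

in_energy = [0.1, 0.2, 0.3]     # hypothetical energies of in-distribution test images
ood_energy = [0.25, 0.4, 0.5]   # hypothetical energies of OOD images
print(auroc(in_energy, ood_energy))  # 8 of 9 pairs ranked correctly
```

The pairwise comparison uses O(n_in * n_out) memory, so for full test sets a rank-based implementation such as scikit-learn's roc_auc_score is more practical.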
Use train.py to train NAE.
- --config specifies a path to a configuration YAML file.
- --logdir specifies a directory where result files will be written.
- --run specifies an ID for each run, i.e., an experiment.
Training on MNIST
python train.py --config configs/mnist_ood_nae/z32.yml --logdir results/mnist_ood_nae/ --run run --device 0
Training on MNIST digits 0 to 8 for the hold-out digit detection task
python train.py --config configs/mnist_ho_nae/l2_z32.yml --logdir results/mnist_ho_nae --run run --device 0
Training on CIFAR-10
python train.py --config configs/cifar_ood_nae/z32gn.yml --logdir results/cifar_ood_nae/ --run run --device 0
Training on CelebA 64x64
python train.py --config configs/celeba64_ood_nae/z64gr_h32g8.yml --logdir results/celeba64_ood_nae/ --run run --device 0
Training on FashionMNIST
python train.py --config configs/fmnist_ood_nae/z32.yml --logdir results/fmnist_ood_nae --run run --device 0
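The options documented for train.py can be mirrored with argparse, as in the sketch below. This is a hypothetical parser covering only the four options used in the commands above; the actual train.py may define more options, other types, or defaults.

```python
import argparse

def build_parser():
    # Hypothetical mirror of the documented train.py options, not the real parser.
    parser = argparse.ArgumentParser(description="Train a Normalized Autoencoder (sketch)")
    parser.add_argument("--config", help="path to a configuration YAML file")
    parser.add_argument("--logdir", help="directory where result files will be written")
    parser.add_argument("--run", help="ID of this run, i.e., an experiment")
    parser.add_argument("--device", help="device to use, e.g. a GPU index")
    return parser

# Parse the MNIST training command from above.
args = build_parser().parse_args(
    ["--config", "configs/mnist_ood_nae/z32.yml",
     "--logdir", "results/mnist_ood_nae/",
     "--run", "run", "--device", "0"]
)
print(args.config, args.logdir, args.run, args.device)
```

With argparse's default store action, repeating an option keeps only its last value.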
Use sample.py to generate sample images from NAE. Samples are saved as a .npy file containing an (n_sample, img_h, img_w, channels) array.
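To inspect the saved samples, the array can be loaded with np.load and tiled into a single grid image. The sketch below assumes only the (n_sample, img_h, img_w, channels) layout stated above; tile_grid is a hypothetical helper and the file path in the comment is a placeholder, not a name sample.py is known to write.

```python
import numpy as np

def tile_grid(samples, ncols):
    """Tile an (n, h, w, c) array into one (rows * h, ncols * w, c) image.

    Empty cells in the last row (when n is not a multiple of ncols) are zero-filled.
    """
    n, h, w, c = samples.shape
    nrows = -(-n // ncols)  # ceiling division
    padded = np.zeros((nrows * ncols, h, w, c), dtype=samples.dtype)
    padded[:n] = samples
    grid = padded.reshape(nrows, ncols, h, w, c)
    # (rows, cols, h, w, c) -> (rows, h, cols, w, c) -> (rows * h, cols * w, c)
    return grid.transpose(0, 2, 1, 3, 4).reshape(nrows * h, ncols * w, c)

# samples = np.load("path/to/samples.npy")  # placeholder path
samples = np.random.rand(64, 28, 28, 1)     # stand-in for 64 MNIST-sized samples
grid = tile_grid(samples, ncols=8)
print(grid.shape)  # (224, 224, 1)
```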
Note that the quality of generated images is not expected to match that of state-of-the-art generative models. Improving the sample quality is one of the important future research directions.
Sampling for MNIST
python sample.py pretrained/mnist_ood_nae/z32/ z32.yml nae_20.pkl --zstep 200 --x_shape 28 --batch_size 64 --n_sample 64 --x_channel 1 --device 0
The white square that appears in the generated MNIST samples is an artifact of NAE, possibly caused by distortion in the encoder and the decoder.
The result is comparable to the samples from a vanilla autoencoder generated with the same procedure.
Sampling for CIFAR-10
python sample.py pretrained/cifar_ood_nae/z32gn/ z32gn.yml nae_8.pkl --zstep 180 --xstep 40 --batch_size 64 --n_sample 64 --name run --device 0
Sampling for CelebA 64x64
python sample.py pretrained/celeba64_ood_nae/z64gr_h32g8/ z64gr_h32g8.yml nae_3.pkl --zstep 180 --xstep 40 --batch_size 64 --n_sample 64 --name run --device 0 --x_shape 64
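Samples from an energy-based model are typically drawn with Langevin Monte Carlo. A reasonable reading of the --zstep and --xstep options is that they set the number of Langevin steps run first in the latent space (on the energy of decoded latents) and then in the input space; that reading is an assumption, as are the step size and temperature below. The sketch is generic unadjusted Langevin dynamics on a toy energy, not the repository's sampler.

```python
import numpy as np

def langevin(grad_energy, x0, n_steps, step_size, temperature=1.0, rng=None):
    """Unadjusted Langevin dynamics targeting p(x) proportional to exp(-E(x) / T).

    Each step: x <- x - (step_size / (2T)) * grad E(x) + sqrt(step_size) * noise.
    """
    rng = np.random.default_rng() if rng is None else rng
    x = np.array(x0, dtype=float)
    for _ in range(n_steps):
        noise = rng.standard_normal(x.shape)
        x = x - step_size / (2.0 * temperature) * grad_energy(x) + np.sqrt(step_size) * noise
    return x

# Toy energy E(x) = ||x||^2 / 2, so grad E(x) = x and the target is a standard normal.
rng = np.random.default_rng(0)
x0 = np.full((5000, 2), 5.0)  # 5000 chains, all started far from the mode
samples = langevin(lambda x: x, x0, n_steps=500, step_size=0.1, rng=rng)
print(samples.mean(axis=0), samples.var(axis=0))  # mean near 0, variance near 1
```

Because no Metropolis correction is applied, the chains carry a small discretization bias (here the stationary variance is slightly above 1); smaller step sizes reduce it at the cost of more steps.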
@InProceedings{pmlr-v139-yoon21c,
title = {Autoencoding Under Normalization Constraints},
author = {Yoon, Sangwoong and Noh, Yung-Kyun and Park, Frank},
booktitle = {Proceedings of the 38th International Conference on Machine Learning},
pages = {12087--12097},
year = {2021},
editor = {Meila, Marina and Zhang, Tong},
volume = {139},
series = {Proceedings of Machine Learning Research},
month = {18--24 Jul},
publisher = {PMLR},
pdf = {http://proceedings.mlr.press/v139/yoon21c/yoon21c.pdf},
url = {https://proceedings.mlr.press/v139/yoon21c.html}
}