This is the codebase for the ICML 2023 paper Input Perturbation Reduces Exposure Bias in Diffusion Models.
This repository is heavily based on openai/guided-diffusion, with a small training-time modification that adds input perturbation.
Also, feel free to check out our ICLR 2024 paper Elucidating the Exposure Bias in Diffusion Models, which introduces a simple training-free solution to exposure bias. Repositories: ADM-ES and EDM-ES
Our proposed Input Perturbation is an extremely simple plug-in method for general diffusion models; its implementation is just two lines of code.
For instance, based on guided-diffusion, the only code modification is in the script guided_diffusion/gaussian_diffusion.py, lines 765-766:
new_noise = noise + gamma * th.randn_like(noise)  # perturb the noise; gamma=0.1
x_t = self.q_sample(x_start, t, noise=new_noise)  # build x_t from the perturbed noise
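Viewed on its own, the change amounts to the following sketch of the perturbed forward (noising) step. This is only an illustration with hypothetical names (q_sample_with_perturbation and the precomputed sqrt_ac_t / sqrt_1m_ac_t factors are ours); the actual code lives in guided_diffusion/gaussian_diffusion.py as shown above:

import torch as th

def q_sample_with_perturbation(x_start, sqrt_ac_t, sqrt_1m_ac_t, gamma=0.1):
    # standard DDPM target noise
    noise = th.randn_like(x_start)
    # perturbed noise, used only to construct x_t; it mimics the prediction
    # error the network will see at sampling time
    new_noise = noise + gamma * th.randn_like(noise)
    x_t = sqrt_ac_t * x_start + sqrt_1m_ac_t * new_noise
    # the model is still trained to predict `noise` from this x_t
    return x_t, noise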
NOTE: change the parameter GPUS_PER_NODE = 4 in the script dist_util.py according to your GPU cluster configuration.
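That is, the relevant line in dist_util.py looks like the following (4 is just the shipped default):

GPUS_PER_NODE = 4  # set to the number of GPUs on each node of your cluster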
Installation is the same as for guided-diffusion:
git clone https://github.com/forever208/DDPM-IP.git
cd DDPM-IP
conda create -n ADM python=3.8
conda activate ADM
pip install -e .
(Note: PyTorch 1.10~1.13 is recommended, since the experiments in the paper were run with PyTorch 1.10; PyTorch 2.0 has not been tested with this repo.)
# install the missing packages
conda install mpi4py
conda install numpy
pip install Pillow
pip install opencv-python
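Optionally, a quick throwaway check (not part of the repo) that the key packages import correctly:

import torch, mpi4py, PIL, cv2
print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())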
We have released checkpoints for the main models in the paper.
(The baseline checkpoints for ImageNet-32 and CelebA-64 are missing due to an unexpected server file deletion. If you have trained the ADM base models, you are welcome to share the checkpoints.)
Here are the download links for model checkpoints:
To sample unconditionally from these models, use the image_sample.py script.
Sampling from DDPM-IP is no different from sampling from openai/guided-diffusion, since DDPM-IP does not change the sampling process.
For example, we sample 50k images from CIFAR10 using 100 steps by:
mpirun python scripts/image_sample.py \
--image_size 32 --timestep_respacing 100 \
--model_path PATH_TO_CHECKPOINT \
--num_channels 128 --num_head_channels 32 --num_res_blocks 3 --attention_resolutions 16,8 \
--resblock_updown True --use_new_attention_order True --learn_sigma True --dropout 0.3 \
--diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True --batch_size 256 --num_samples 50000
Sample 50k images from LSUN tower using 100 steps by:
mpirun -n 1 python scripts/image_sample.py \
--image_size 64 --timestep_respacing 100 \
--model_path PATH_TO_CHECKPOINT \
--use_fp16 True --num_channels 192 --num_head_channels 64 --num_res_blocks 3 \
--attention_resolutions 32,16,8 --resblock_updown True --use_new_attention_order True \
--learn_sigma True --dropout 0.1 --diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True \
--rescale_learned_sigmas True --batch_size 256 --num_samples 50000
Sample 50k images from FFHQ 128 using 100 steps by:
mpirun -n 1 python scripts/image_sample.py \
--image_size 128 --timestep_respacing 100 \
--model_path PATH_TO_CHECKPOINT \
--use_fp16 True --num_channels 256 --num_head_channels 64 --num_res_blocks 3 \
--attention_resolutions 32,16,8 --resblock_updown True --use_new_attention_order True \
--learn_sigma True --dropout 0.1 --diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True \
--rescale_learned_sigmas True --batch_size 128 --num_samples 50000
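image_sample.py saves the generated batch as an .npz file. If your FID tool expects a folder of images, a minimal conversion sketch (assuming the guided-diffusion convention that the samples are a uint8 array of shape N x H x W x 3 stored under the arr_0 key) could look like:

import os, sys
import numpy as np
from PIL import Image

npz_path, out_dir = sys.argv[1], sys.argv[2]
samples = np.load(npz_path)["arr_0"]  # assumed uint8, shape (N, H, W, 3)
os.makedirs(out_dir, exist_ok=True)
for i, img in enumerate(samples):
    Image.fromarray(img).save(os.path.join(out_dir, f"{i:06d}.png"))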
This table summarizes our input perturbation results based on ADM baselines. Input perturbation brings substantial training acceleration and much better FID scores.
FID computation details:
This table summarizes our input perturbation results based on DDIM baselines.
Please refer to README.md for the data preparation.
Training diffusion models is described in this repository.
Training ADM-IP only requires one extra argument, --input_pertub 0.1
(set --input_pertub 0.0 for the baseline).
NOTE: if you have problems with Slurm multi-node training, try the following setting. Say you train with 16 GPUs on 2 nodes:
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=8
#SBATCH --cpus-per-task=6
#SBATCH --gres=gpu:8 # 8 gpus for each node
Instead of specifying mpiexec -n 16, run with mpirun python scripts/image_train.py. (More discussion can be found here.)
We share the complete training arguments of ADM-IP for the following datasets:
CIFAR10
mpiexec -n 2 python scripts/image_train.py --input_pertub 0.15 \
--data_dir PATH_TO_DATASET \
--image_size 32 --use_fp16 True --num_channels 128 --num_head_channels 32 --num_res_blocks 3 \
--attention_resolutions 16,8 --resblock_updown True --use_new_attention_order True \
--learn_sigma True --dropout 0.3 --diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True \
--rescale_learned_sigmas True --schedule_sampler loss-second-moment --lr 1e-4 --batch_size 64
ImageNet 32x32 (you can also choose dropout=0.1)
mpiexec -n 4 python scripts/image_train.py --input_pertub 0.1 \
--data_dir PATH_TO_DATASET \
--image_size 32 --use_fp16 True --num_channels 128 --num_head_channels 32 --num_res_blocks 3 \
--attention_resolutions 16,8 --resblock_updown True --use_new_attention_order True \
--learn_sigma True --dropout 0.3 --diffusion_steps 1000 --noise_schedule cosine \
--rescale_learned_sigmas True --schedule_sampler loss-second-moment --lr 1e-4 --batch_size 128
LSUN tower 64x64
mpiexec -n 16 python scripts/image_train.py --input_pertub 0.1 \
--data_dir PATH_TO_DATASET \
--image_size 64 --use_fp16 True --num_channels 192 --num_head_channels 64 --num_res_blocks 3 \
--attention_resolutions 32,16,8 --resblock_updown True --use_new_attention_order True \
--learn_sigma True --dropout 0.1 --diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True \
--rescale_learned_sigmas True --schedule_sampler loss-second-moment --lr 1e-4 --batch_size 16
CelebA 64x64
mpiexec -n 16 python scripts/image_train.py --input_pertub 0.1 \
--data_dir PATH_TO_DATASET \
--image_size 64 --use_fp16 True --num_channels 192 --num_head_channels 64 --num_res_blocks 3 \
--attention_resolutions 32,16,8 --resblock_updown True --use_new_attention_order True \
--learn_sigma True --dropout 0.1 --diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True \
--rescale_learned_sigmas True --schedule_sampler loss-second-moment --lr 1e-4 --batch_size 16
FFHQ 128x128
mpirun -n 16 python scripts/image_train.py --input_pertub 0.1 \
--data_dir PATH_TO_DATASET \
--image_size 128 --use_fp16 True --num_channels 256 --num_head_channels 64 --num_res_blocks 3 \
--attention_resolutions 32,16,8 --resblock_updown True --use_new_attention_order True \
--learn_sigma True --dropout 0.1 --diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True \
--rescale_learned_sigmas True --schedule_sampler loss-second-moment --lr 1e-4 --batch_size 8
If you find our work useful, please feel free to cite it as follows:
@inproceedings{ning2023input,
title={Input Perturbation Reduces Exposure Bias in Diffusion Models},
author={Ning, Mang and Sangineto, Enver and Porrello, Angelo and Calderara, Simone and Cucchiara, Rita},
booktitle={International Conference on Machine Learning},
pages={26245--26265},
year={2023},
organization={PMLR}
}
@inproceedings{ningelucidating,
title={Elucidating the Exposure Bias in Diffusion Models},
author={Ning, Mang and Li, Mingxiao and Su, Jianlin and Salah, Albert Ali and Ertugrul, Itir Onal},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024}
}