forever208 / DDPM-IP

[ICML 2023] official implementation for "Input Perturbation Reduces Exposure Bias in Diffusion Models"
MIT License

DDPM-IP

This is the codebase for the ICML 2023 paper Input Perturbation Reduces Exposure Bias in Diffusion Models.
This repository is heavily based on openai/guided-diffusion, with input perturbation added to the training procedure.

Also, feel free to check out our ICLR 2024 paper Elucidating the Exposure Bias in Diffusion Models, which introduces a simple training-free solution to exposure bias. Repositories: ADM-ES and EDM-ES.

Simple to implement Input Perturbation in diffusion models

Our proposed Input Perturbation is an extremely simple plug-in method for general diffusion models. The implementation of Input Perturbation is just two lines of code.

For instance, based on guided-diffusion, the only code modifications are in the script guided_diffusion/gaussian_diffusion.py, at lines 765-766:

new_noise = noise + gamma * th.randn_like(noise)  # gamma=0.1
x_t = self.q_sample(x_start, t, noise=new_noise)
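To make the two lines above concrete outside the guided-diffusion class, here is a minimal standalone sketch. It assumes the standard forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise and uses plain Python lists in place of tensors. The helper names (q_sample as a free function, perturbed_training_pair) and the scalar alpha_bar_t argument are illustrative assumptions, not the repository's API. The point it illustrates is that only the network input is built from the perturbed noise, while the regression target remains the original noise.

```python
import math
import random


def q_sample(x_start, alpha_bar_t, noise):
    """Forward-noise x_start: sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise."""
    a = math.sqrt(alpha_bar_t)
    b = math.sqrt(1.0 - alpha_bar_t)
    return [a * x + b * n for x, n in zip(x_start, noise)]


def perturbed_training_pair(x_start, alpha_bar_t, gamma=0.1, rng=None):
    """Return (x_t, target_noise) for one training example (hypothetical helper).

    x_t is built from the perturbed noise; the target is the unperturbed noise.
    """
    rng = rng or random.Random()
    noise = [rng.gauss(0.0, 1.0) for _ in x_start]
    extra = [rng.gauss(0.0, 1.0) for _ in x_start]
    # Same operation as: new_noise = noise + gamma * th.randn_like(noise)
    new_noise = [n + gamma * e for n, e in zip(noise, extra)]
    x_t = q_sample(x_start, alpha_bar_t, new_noise)
    return x_t, noise


if __name__ == "__main__":
    x_t, target = perturbed_training_pair([0.5, -0.25, 1.0], alpha_bar_t=0.64)
    print(x_t, target)
```

Setting gamma=0.0 in this sketch reduces it to the unperturbed baseline.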

NOTE THAT: change the parameter GPUS_PER_NODE = 4 in the script dist_util.py according to your GPU cluster configuration.
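As a rough illustration of why this value matters, the sketch below assumes that each MPI process selects its GPU as rank modulo GPUS_PER_NODE (this is our reading of dist_util.py in guided-diffusion; please verify it against your checkout) and that ranks are filled node by node. The function name local_gpu_index is hypothetical.

```python
def local_gpu_index(rank: int, gpus_per_node: int) -> int:
    """Device index a rank would select under the assumed rank % GPUS_PER_NODE rule."""
    if gpus_per_node < 1:
        raise ValueError("gpus_per_node must be at least 1")
    return rank % gpus_per_node


if __name__ == "__main__":
    # 16 ranks on 2 nodes with 8 GPUs each: every GPU on a node gets one rank.
    print([local_gpu_index(r, 8) for r in range(16)])
    # Same cluster with the value left at 4: ranks collide on GPUs 0-3.
    print([local_gpu_index(r, 4) for r in range(8)])
```

Under these assumptions, leaving the value at 4 on 8-GPU nodes would place two processes on each of the first four GPUs and leave the rest idle.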

Installation

The installation is the same as for guided-diffusion:

git clone https://github.com/forever208/DDPM-IP.git
cd DDPM-IP
conda create -n ADM python=3.8
conda activate ADM
pip install -e .
(Note: PyTorch 1.10 to 1.13 is recommended, since the experiments in the paper were run with PyTorch 1.10; PyTorch 2.0 has not been tested with this repository.)

# install the missing packages
conda install mpi4py
conda install numpy
pip install Pillow
pip install opencv-python

Download ADM-IP models and ADM base models

We have released checkpoints for the main models in the paper.

(The baseline checkpoints for ImageNet-32 and CelebA-64 are missing due to an unexpected server file deletion. If you have trained these ADM base models, you are welcome to share the checkpoints.)

Here are the download links for model checkpoints:

Sampling from pre-trained ADM-IP models

To sample unconditionally from these models, use the scripts/image_sample.py script. Sampling from DDPM-IP is identical to sampling from openai/guided-diffusion, since DDPM-IP does not change the sampling process.

For example, to sample 50k images from the CIFAR10 model using 100 steps:

mpirun python scripts/image_sample.py \
--image_size 32 --timestep_respacing 100 \
--model_path PATH_TO_CHECKPOINT \
--num_channels 128 --num_head_channels 32 --num_res_blocks 3 --attention_resolutions 16,8 \
--resblock_updown True --use_new_attention_order True --learn_sigma True --dropout 0.3 \
--diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True --batch_size 256 --num_samples 50000
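In the command above, --diffusion_steps 1000 together with --timestep_respacing 100 means the sampler visits 100 of the 1000 training timesteps. The following sketch shows one simple way to pick evenly spaced timesteps. The function name evenly_spaced_timesteps is hypothetical and the rounding is a simplification, so the exact indices may differ from those chosen by the respacing code in guided-diffusion.

```python
def evenly_spaced_timesteps(num_timesteps: int, num_sampling_steps: int) -> list:
    """Pick num_sampling_steps indices spread evenly over [0, num_timesteps - 1]."""
    if not 1 <= num_sampling_steps <= num_timesteps:
        raise ValueError("need 1 <= num_sampling_steps <= num_timesteps")
    if num_sampling_steps == 1:
        return [0]
    # Fractional stride so that the first and last timesteps are both included.
    stride = (num_timesteps - 1) / (num_sampling_steps - 1)
    return [round(i * stride) for i in range(num_sampling_steps)]


if __name__ == "__main__":
    steps = evenly_spaced_timesteps(1000, 100)
    print(len(steps), steps[0], steps[-1])
```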

To sample 50k images from the LSUN tower model using 100 steps:

mpirun -n 1 python scripts/image_sample.py \
--image_size 64 --timestep_respacing 100 \
--model_path PATH_TO_CHECKPOINT \
--use_fp16 True --num_channels 192 --num_head_channels 64 --num_res_blocks 3 \
--attention_resolutions 32,16,8 --resblock_updown True --use_new_attention_order True \
--learn_sigma True --dropout 0.1 --diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True \
--rescale_learned_sigmas True --batch_size 256 --num_samples 50000

To sample 50k images from the FFHQ128 model using 100 steps:

mpirun -n 1 python scripts/image_sample.py \
--image_size 128 --timestep_respacing 100 \
--model_path PATH_TO_CHECKPOINT \
--use_fp16 True --num_channels 256 --num_head_channels 64 --num_res_blocks 3 \
--attention_resolutions 32,16,8 --resblock_updown True --use_new_attention_order True \
--learn_sigma True --dropout 0.1 --diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True \
--rescale_learned_sigmas True --batch_size 128 --num_samples 50000

Results

This table summarizes our input perturbation results based on ADM baselines. Input perturbation shows tremendous training acceleration and much better FID results.

FID computation details:

This table summarizes our input perturbation results based on DDIM baselines.

Prepare datasets

Please refer to README.md for the data preparation.

Training ADM-IP

Training diffusion models is described in this repository.

Training ADM-IP requires only one additional argument, --input_pertub 0.1 (set --input_pertub 0.0 for the baseline).

NOTE THAT: if you have problems with Slurm multi-node training, try the following settings. For example, to train with 16 GPUs on 2 nodes:

#SBATCH --nodes=2
#SBATCH --ntasks-per-node=8
#SBATCH --cpus-per-task=6
#SBATCH --gres=gpu:8 # 8 gpus for each node

Then, instead of specifying mpiexec -n 16, launch with mpirun python scripts/image_train.py. (More discussion can be found here.)

We share the complete training arguments for ADM-IP on each of the five datasets:

CIFAR10

mpiexec -n 2  python scripts/image_train.py --input_pertub 0.15 \
--data_dir PATH_TO_DATASET \
--image_size 32 --use_fp16 True --num_channels 128 --num_head_channels 32 --num_res_blocks 3 \
--attention_resolutions 16,8 --resblock_updown True --use_new_attention_order True \
--learn_sigma True --dropout 0.3 --diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True \
--rescale_learned_sigmas True --schedule_sampler loss-second-moment --lr 1e-4 --batch_size 64

ImageNet 32x32 (you can also choose dropout=0.1)

mpiexec -n 4  python scripts/image_train.py --input_pertub 0.1 \
--data_dir PATH_TO_DATASET \
--image_size 32 --use_fp16 True --num_channels 128 --num_head_channels 32 --num_res_blocks 3 \
--attention_resolutions 16,8 --resblock_updown True --use_new_attention_order True \
--learn_sigma True --dropout 0.3 --diffusion_steps 1000 --noise_schedule cosine \
--rescale_learned_sigmas True --schedule_sampler loss-second-moment --lr 1e-4 --batch_size 128

LSUN tower 64x64

mpiexec -n 16  python scripts/image_train.py --input_pertub 0.1 \
--data_dir PATH_TO_DATASET \
--image_size 64 --use_fp16 True --num_channels 192 --num_head_channels 64 --num_res_blocks 3 \
--attention_resolutions 32,16,8 --resblock_updown True --use_new_attention_order True \
--learn_sigma True --dropout 0.1 --diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True \
--rescale_learned_sigmas True --schedule_sampler loss-second-moment --lr 1e-4 --batch_size 16

CelebA 64x64

mpiexec -n 16  python scripts/image_train.py --input_pertub 0.1 \
--data_dir PATH_TO_DATASET \
--image_size 64 --use_fp16 True --num_channels 192 --num_head_channels 64 --num_res_blocks 3 \
--attention_resolutions 32,16,8 --resblock_updown True --use_new_attention_order True \
--learn_sigma True --dropout 0.1 --diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True \
--rescale_learned_sigmas True --schedule_sampler loss-second-moment --lr 1e-4 --batch_size 16

FFHQ 128x128

mpirun -n 16 python scripts/image_train.py --input_pertub 0.1 \
--data_dir PATH_TO_DATASET \
--image_size 128 --use_fp16 True --num_channels 256 --num_head_channels 64 --num_res_blocks 3 \
--attention_resolutions 32,16,8 --resblock_updown True --use_new_attention_order True \
--learn_sigma True --dropout 0.1 --diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True \
--rescale_learned_sigmas True --schedule_sampler loss-second-moment --lr 1e-4 --batch_size 8
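When adapting the commands above to a different number of GPUs, it helps to keep the effective batch size in mind. The sketch below assumes that --batch_size is counted per MPI process, so the global batch is the process count times --batch_size (this is our reading of the guided-diffusion training loop; please verify it in your checkout). The dictionary simply restates the -n and --batch_size values from the commands above, and global_batch_size is a hypothetical helper.

```python
# (number of MPI processes, per-process --batch_size) taken from the commands above
TRAINING_RUNS = {
    "CIFAR10": (2, 64),
    "ImageNet 32x32": (4, 128),
    "LSUN tower 64x64": (16, 16),
    "CelebA 64x64": (16, 16),
    "FFHQ 128x128": (16, 8),
}


def global_batch_size(num_processes: int, per_process_batch: int) -> int:
    """Effective batch size, assuming --batch_size applies to each process."""
    return num_processes * per_process_batch


if __name__ == "__main__":
    for name, (n, bs) in TRAINING_RUNS.items():
        print(f"{name}: {n} x {bs} = {global_batch_size(n, bs)}")
```

Under that assumption, halving the number of processes would call for doubling --batch_size to keep the same global batch.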

Citation

If you find our work useful, please cite:

@inproceedings{ning2023input,
  title={Input Perturbation Reduces Exposure Bias in Diffusion Models},
  author={Ning, Mang and Sangineto, Enver and Porrello, Angelo and Calderara, Simone and Cucchiara, Rita},
  booktitle={International Conference on Machine Learning},
  pages={26245--26265},
  year={2023},
  organization={PMLR}
}
@inproceedings{ningelucidating,
  title={Elucidating the Exposure Bias in Diffusion Models},
  author={Ning, Mang and Li, Mingxiao and Su, Jianlin and Salah, Albert Ali and Ertugrul, Itir Onal},
  booktitle={The Twelfth International Conference on Learning Representations}
}