sndnyang / iDDPM

My pipeline wrapper of 'Improved Denoising Diffusion Probabilistic Models'.
MIT License

improved-diffusion iDDPM

Thanks a lot to the authors of the official code: https://github.com/openai/improved-diffusion

However, I failed to install mpi4py for multi-GPU/distributed training, and mpi4py and torch.distributed are not needed for prototyping anyway. So I removed those components; the code runs on a single GPU without mpi4py.

During training, I generate and save images and evaluate IS/FID/KID on them. A replay buffer stores the generated images for the IS/FID/KID computation.
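The replay-buffer idea can be sketched as below. This is my own minimal illustration, not the repository's actual implementation; the class and method names are hypothetical.

```python
# Hypothetical sketch of a fixed-size replay buffer for generated
# images; once full, the oldest entries are overwritten so the
# buffer always holds the most recent samples for IS/FID/KID.
import random


class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.capacity = capacity
        self.data = []
        self.pos = 0  # next slot to overwrite once full

    def add(self, batch):
        # batch: an iterable of generated images (e.g. tensors)
        for img in batch:
            if len(self.data) < self.capacity:
                self.data.append(img)
            else:
                self.data[self.pos] = img
                self.pos = (self.pos + 1) % self.capacity

    def sample(self, n):
        # draw up to n images for metric evaluation
        return random.sample(self.data, min(n, len(self.data)))
```

Evaluating on the buffer amortizes sampling cost across training, which is why its FID can differ from generating a fresh batch from scratch.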

pip install torch torch_fidelity

optional: wandb

Usage

Contains the code for: 1) training, 2) generation and FID, 3) NLL/BPD, i.e. negative log-likelihood in bits per dimension.

Training

The training command is simple.

python iddpm_main.py --gpu-id 0 --print-freq 100 --note warmup    [--no_fid]

For hyperparameters, refer to https://github.com/openai/improved-diffusion#models-and-hyperparameters and change them in iddpm/script_util.py.
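For reference, the upstream CIFAR-10 settings look roughly like the following. This is a hedged sketch based on the official README's hyperparameter table; the exact key names in this fork's iddpm/script_util.py may differ.

```python
# Illustrative CIFAR-10 hyperparameters, following the
# improved-diffusion README; key names here are my own
# approximation, not this fork's exact defaults.
cifar10_defaults = dict(
    image_size=32,
    num_channels=128,
    num_res_blocks=3,
    learn_sigma=True,        # pairs with ModelVarType.LEARNED at eval time
    diffusion_steps=4000,
    noise_schedule="cosine",
)
```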

Note

A larger diffusion_steps makes generation during training much slower. You can use --no_fid to skip the sampling.

My results

The figure of training curves and logs:

On wandb.ai

https://wandb.ai/sndnyang/biggest/reports/iDDPM--VmlldzoyNzQwMTQ4?accessToken=spxh6nc9asp4l1og8hoifeym9oxkx4j8av78hkyp06s9jy4gsjf7zt5yezg3gceo

The replay buffer's FID is ~10, while generating 10k images from scratch only achieves ~16.7.

Sampling and Evaluating IS/FID/KID

I use torch_fidelity.
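A minimal sketch of the metric computation with torch_fidelity's calculate_metrics; the wrapper function, the sample directory path, and the reference dataset name are my own placeholders, not the repo's code.

```python
# Hedged sketch: compute IS/FID/KID over a directory of generated
# PNGs against a registered reference dataset via torch_fidelity.
def evaluate_samples(sample_dir, reference="cifar10-train"):
    from torch_fidelity import calculate_metrics

    metrics = calculate_metrics(
        input1=sample_dir,   # directory of generated images
        input2=reference,    # registered dataset name or another directory
        isc=True,            # Inception Score
        fid=True,            # Frechet Inception Distance
        kid=True,            # Kernel Inception Distance
        cuda=True,
    )
    # metrics is a dict keyed by metric name, e.g.
    # 'inception_score_mean', 'frechet_inception_distance'
    return metrics
```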

To reload other models or datasets, check iddpm/script_util.py and utils/eval_quality.py.

python iddpm_eval.py --eval gen --resume path/your/checkpoint.pth --gpu-id 0

A comparison with the TensorFlow implementation: https://github.com/sndnyang/inception_score_fid/blob/master/eval_is_fid_torch_fidelity.ipynb

DDIM

DDIM sampling also uses the official implementation; it only takes one command:

python iddpm_eval.py --eval gen --use_ddim --resume cifar10_uncond_50M_500K.pt --gpu-id 0 --timestep_respacing 50
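Under the hood, timestep respacing selects an evenly spaced subsequence of the original diffusion steps, so DDIM can sample in 50 steps instead of 1000. This is my own minimal sketch of the idea, not the repo's exact logic.

```python
# Hedged sketch: pick an evenly spaced subsequence of timesteps,
# e.g. 50 out of 1000 for --timestep_respacing 50.
def respace_timesteps(num_original=1000, num_respaced=50):
    stride = num_original / num_respaced
    return [round(i * stride) for i in range(num_respaced)]
```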

Speed: a batch of 10 (out of 1000 images) takes 0:01:17.226265

Logs: logs/ddim_sampling_50.log

Evaluate NLL

Negative log-likelihood, reported in bits per dimension (BPD).

python iddpm_eval.py --eval nll --resume path/your/checkpoint.pth --gpu-id 0
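For context, BPD is the per-image NLL in nats divided by ln(2) and the data dimensionality. A minimal conversion sketch (my own helper, not a function from the repo):

```python
import math


def nll_to_bpd(nll_nats, num_dims):
    # Convert a per-image negative log-likelihood in nats into
    # bits per dimension: divide by ln(2) and the dimensionality.
    return nll_nats / (math.log(2) * num_dims)


# e.g. a 32x32x3 CIFAR-10 image has 3072 dimensions
```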

Note

Check lines 119-122 in iddpm_eval.py and make sure they are consistent with the trained checkpoint:

        model_mean_type=ModelMeanType.EPSILON,
        model_var_type=ModelVarType.LEARNED,
        loss_type=LossType.MSE,

Note that this evaluation is very slow.

Change Log

2022.10.31

Faster sampling via the --use_ddim flag

2022.10.30

Notebooks for

  1. Evaluate NLL
  2. Generate and FID

TODO

  1. Faster sampling (Not working on it)

2022.10.04

Evaluation

  1. Generation and Evaluate IS/FID/KID
  2. Evaluate NLL

TODO

  1. Faster sampling

2022.10.03

Init code

  1. Training iddpm_main.py