tqch / ddpm-torch

Unofficial PyTorch Implementation of Denoising Diffusion Probabilistic Models (DDPM)
MIT License
197 stars, 35 forks

FID for CIFAR10 checkpoint #13

Closed (tj-14 closed this issue 1 year ago)

tj-14 commented 1 year ago

Hi, I tried generating samples with the example command

`python generate.py --dataset cifar10 --chkpt-path ./chkpts/cifar10/cifar10_2040.pt --use-ddim --skip-schedule quadratic --subseq-size 100 --suffix _ddim --num-gpus 4`

using the provided checkpoint, and calculated the FID with the provided script. The FID I got is 4.346242031101042, which differs from the value in the reported evaluation metrics table. How can I use the checkpoint to reproduce the reported number? Thanks!
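For context, FID fits a Gaussian to the Inception-v3 pool features (2048-dimensional) of real and generated images and reports the Fréchet distance between the two Gaussians, $\lVert \mu_1 - \mu_2 \rVert^2 + \mathrm{Tr}\big(\Sigma_1 + \Sigma_2 - 2(\Sigma_1 \Sigma_2)^{1/2}\big)$. Below is a minimal sketch of that final step, assuming the features have already been extracted. `frechet_distance` and `feature_stats` are hypothetical helpers for illustration, not the repo's FID script.

```python
import numpy as np
from scipy import linalg


def feature_stats(features):
    """Mean and covariance of an (N, D) matrix of Inception features."""
    return features.mean(axis=0), np.cov(features, rowvar=False)


def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Fréchet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    diff = mu1 - mu2

    # Matrix square root of the covariance product
    covmean = linalg.sqrtm(sigma1 @ sigma2)
    if not np.isfinite(covmean).all():
        # Ill-conditioned product: add a small offset to both diagonals and retry
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))
    # Numerical error can leave tiny imaginary parts; keep the real part
    covmean = np.real(covmean)

    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean))
```

In use, one would call `feature_stats` on the real and generated feature matrices and pass both results to `frechet_distance`. The diagonal-offset fallback mirrors the one used in common FID implementations such as pytorch-fid.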

tqch commented 1 year ago

Hi there! It seems the example command you used runs the DDIM sampler, a fast sampler (10x faster in this case: 100 sampling steps vs. 1000) proposed in a separate work, Denoising Diffusion Implicit Models (DDIM) [^1]. To reproduce the DDPM results reported in this repo, please remove the `--use-ddim` flag.
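To illustrate what the DDIM options in the command control: `--subseq-size 100` and `--skip-schedule quadratic` presumably select which 100 of the 1000 timesteps the sampler visits. The sketch below assumes the "uniform"/"quad" skip types of the official DDIM code (ermongroup/ddim) and the linear beta schedule of the original DDPM. `skip_timesteps` and `ddim_step` are hypothetical helpers, not ddpm-torch's API, and the repo's exact constants may differ.

```python
import numpy as np


def skip_timesteps(num_timesteps=1000, subseq_size=100, schedule="quadratic"):
    """Choose `subseq_size` of the `num_timesteps` diffusion steps to visit.

    Hypothetical helper modeled on the official DDIM code's skip types.
    """
    if schedule == "uniform":
        # Evenly spaced: 0, 10, 20, ..., 990 for T=1000, S=100
        return np.arange(0, num_timesteps, num_timesteps // subseq_size)
    if schedule == "quadratic":
        # Quadratic spacing: dense near t=0 (low noise), sparse near t=T
        seq = np.linspace(0, np.sqrt(num_timesteps * 0.8), subseq_size) ** 2
        return seq.astype(int)  # the integer cast can repeat the earliest steps
    raise ValueError(f"unknown schedule: {schedule}")


def ddim_step(x_t, eps_pred, alpha_bar_t, alpha_bar_prev):
    """Deterministic DDIM update (eta = 0) from step t to the previous kept step."""
    # Predict the clean sample x_0 from the noise estimate
    x0_pred = (x_t - np.sqrt(1.0 - alpha_bar_t) * eps_pred) / np.sqrt(alpha_bar_t)
    # Re-noise x_0 to the lower noise level along the same noise direction
    return np.sqrt(alpha_bar_prev) * x0_pred + np.sqrt(1.0 - alpha_bar_prev) * eps_pred


# Linear beta schedule of the original DDPM (1e-4 to 0.02 over 1000 steps)
betas = np.linspace(1e-4, 0.02, 1000)
alpha_bars = np.cumprod(1.0 - betas)

quad = skip_timesteps(1000, 100, "quadratic")
unif = skip_timesteps(1000, 100, "uniform")
# 100 network evaluations instead of 1000, hence the 10x speed-up
print(len(quad), len(unif))
# The quadratic schedule spends many more of its steps at small t
print((quad < 100).sum(), (unif < 100).sum())
```

Because each DDIM step jumps directly between non-adjacent noise levels, the sampler only needs one network evaluation per kept timestep. DDPM's ancestral sampler, in contrast, steps through every timestep.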

By the way, your DDIM result ($\approx$ 4.35) is actually within a reasonable range of the CIFAR-10 FID reported in the DDIM paper for 100 sampling steps with quadratic skips. See the CIFAR-10 results in the DDIM paper [^1] and also the author's reply on GitHub, which implies an FID of 4.26 with a different seed in the same setting: https://github.com/ermongroup/ddim/issues/3#issuecomment-960431667

[^1]: Song, Jiaming, Chenlin Meng, and Stefano Ermon. "Denoising Diffusion Implicit Models." In International Conference on Learning Representations. 2020.

tj-14 commented 1 year ago

Thank you for your clarification!

Edit: I reran with the DDPM sampler at a subseq-size of 1000 and got an FID of ~3.2, as reported in the table.