crowsonkb / k-diffusion

Karras et al. (2022) diffusion models for PyTorch
MIT License

FID scores from paper #51

Status: Open (opened by darius-lam 1 year ago)

darius-lam commented 1 year ago

Hello, in the paper this repo implements (Karras et al., 2022, the EDM paper), the authors report FID scores for CIFAR-10 in the low single digits (e.g. 1.8). However, the FID scores I get from this repo are much higher: 27, 36, 48, for example.

Is the difference due to how FID is calculated? It's hard to believe the FID could be off by more than an order of magnitude...

lhaippp commented 1 year ago

I got the same issue on CIFAR-10 with batch size 1024. Logged metrics:

step,fid,kid
10000,54.75194549560547,0.024161577224731445
20000,38.80967330932617,0.01055598258972168
30000,35.451629638671875,0.008167028427124023
40000,32.956600189208984,0.005323648452758789
50000,32.06228256225586,0.004504680633544922
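The log also has a kid column. KID (Kernel Inception Distance) is the squared MMD between Inception features under a cubic polynomial kernel, and unlike FID its estimator is unbiased, so it depends much less on the sample count. A minimal sketch, assuming synthetic Gaussian stand-ins for the Inception features (not k-diffusion's own evaluation code); libraries typically average this estimate over several random subsets, while the sketch uses a single subset:

```python
import numpy as np

def polynomial_kernel(x, y):
    # Cubic polynomial kernel k(x, y) = (x . y / d + 1) ** 3, as commonly used for KID
    d = x.shape[1]
    return (x @ y.T / d + 1.0) ** 3

def kid(feats_a, feats_b):
    """Unbiased estimate of squared MMD between two feature sets (single subset)."""
    m, n = len(feats_a), len(feats_b)
    k_aa = polynomial_kernel(feats_a, feats_a)
    k_bb = polynomial_kernel(feats_b, feats_b)
    k_ab = polynomial_kernel(feats_a, feats_b)
    # Dropping the diagonal (each sample paired with itself) makes the within-set means unbiased
    mean_aa = (k_aa.sum() - np.trace(k_aa)) / (m * (m - 1))
    mean_bb = (k_bb.sum() - np.trace(k_bb)) / (n * (n - 1))
    return float(mean_aa + mean_bb - 2.0 * k_ab.mean())

# Hypothetical data: 64-d Gaussian "features" instead of real 2048-d Inception features
rng = np.random.default_rng(0)
real = rng.standard_normal((1000, 64))
same_dist = rng.standard_normal((1000, 64))
shifted = rng.standard_normal((1000, 64)) + 0.5
kid_same = kid(real, same_dist)    # near 0; can even come out slightly negative
kid_shifted = kid(real, shifted)   # clearly positive
print(kid_same, kid_shifted)
```

Because the estimator is unbiased, a matching distribution gives values scattered around 0 rather than a positive floor that shrinks with n.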
madebyollin commented 1 year ago

I was curious about this too... quoting https://arxiv.org/pdf/1801.01401.pdf:

FID scores can only be compared to one another with the same value of n.
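For reference, FID fits a Gaussian to the Inception features of real images (r) and of generated images (g) and takes the Fréchet distance between the two Gaussians:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2 \left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```

In practice the means and covariances are replaced by sample estimates from n images, and this plug-in estimate is biased by an amount that depends on n, which is why the quote above holds.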

If you compare the CIFAR-10 test and train sets (same distribution!) at n=2000, you'd get an FID of around 30. Since 2000 is the default --evaluate-n for k-diffusion training, I think it's expected that the logged FID values will be 30 or higher.
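The n-dependence can be seen without any Inception network. A minimal sketch, assuming synthetic Gaussian "features" (real FID uses 2048-d Inception-v3 features): it estimates FID between two independent samples of the same distribution at several n. The true distance is 0, yet the estimate shrinks steadily as n grows, the same kind of small-n bias behind the train-vs-test value at n=2000. The matrix square root is done with an eigendecomposition to stay numpy-only; standard FID code usually calls scipy.linalg.sqrtm instead.

```python
import numpy as np

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Fréchet distance between Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
    # Tr((S1 S2)^{1/2}) equals Tr((A S2 A)^{1/2}) with A = S1^{1/2}; A S2 A is symmetric PSD,
    # so its square-root trace is the sum of the square roots of its eigenvalues.
    w, v = np.linalg.eigh(sigma1)
    sqrt_s1 = (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T
    tr_covmean = np.sqrt(np.clip(np.linalg.eigvalsh(sqrt_s1 @ sigma2 @ sqrt_s1), 0.0, None)).sum()
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)

def fid_from_features(feats_a, feats_b):
    """Fit a Gaussian to each feature set and return the Fréchet distance between them."""
    return frechet_distance(feats_a.mean(axis=0), np.cov(feats_a, rowvar=False),
                            feats_b.mean(axis=0), np.cov(feats_b, rowvar=False))

rng = np.random.default_rng(0)
dim = 64  # hypothetical stand-in for 2048-d Inception features
results = {}
for n in (200, 2000, 20000):
    # Two independent draws from the SAME distribution, so the true FID is exactly 0
    a = rng.standard_normal((n, dim))
    b = rng.standard_normal((n, dim))
    results[n] = fid_from_features(a, b)
    print(n, round(results[n], 3))
```

The absolute numbers here mean nothing for CIFAR-10; only the trend matters: FIDs computed at different n are not on the same scale.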


The EDM paper reports results with n_fake=50000 and n_real="all available"; FID values computed this way will be much lower than any FID computed with n=2000.


To check the 50k-sample FID for k-diffusion, I trained the default CIFAR-10 config in this repo for 320k steps (checkpoint) and sampled 50k images:

mkdir -p cifar10_fake_samples
python3 sample.py --checkpoint model_cifar10_last.pth --prefix cifar10_fake_samples/sample --config configs/config_cifar10.json -n 50000
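Since the comparison only matches the EDM protocol when 50,000 generated images are evaluated, it can help to confirm the sample directory is complete before scoring it (for example, after an interrupted run). A small sketch; count_images and the suffix set are hypothetical helpers, not part of k-diffusion or EDM:

```python
from pathlib import Path

# Hypothetical set of image extensions to count
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}

def count_images(directory):
    """Return how many image files sit directly inside `directory` (not recursive)."""
    return sum(1 for p in Path(directory).iterdir()
               if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES)

sample_dir = Path("cifar10_fake_samples")  # the directory used in the --prefix above
if sample_dir.is_dir():
    found = count_images(sample_dir)
    # The EDM protocol uses 50,000 generated images; a different count is not comparable
    if found != 50000:
        print(f"warning: expected 50000 samples, found {found}")
```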

[image: model_demo_00320000]

I then followed the evaluation steps from the EDM repo:

cd ..; git clone https://github.com/NVlabs/edm; cd edm
torchrun --standalone --nproc_per_node=1 fid.py calc --images=../k-diffusion/cifar10_fake_samples/ --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz
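The --ref URL points to precomputed reference statistics for the real CIFAR-10 images. My understanding is that EDM's fid.py stores these as an .npz holding the Inception feature mean and covariance under the keys mu and sigma; treat that layout as an assumption and check fid.py before relying on it. Under that assumption, the final number is just the Fréchet distance between two Gaussians, as in this minimal numpy-only sketch (load_stats and fid_from_npz are hypothetical names):

```python
import os
import tempfile
import numpy as np

def load_stats(path):
    # Assumed layout: "mu" with shape [d] and "sigma" with shape [d, d]
    with np.load(path) as data:
        mu, sigma = data["mu"], data["sigma"]
    if mu.ndim != 1 or sigma.shape != (mu.shape[0], mu.shape[0]):
        raise ValueError(f"unexpected stats shapes: {mu.shape}, {sigma.shape}")
    return mu, sigma

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # Eigendecomposition-based square-root trace (numpy-only stand-in for scipy.linalg.sqrtm)
    w, v = np.linalg.eigh(sigma1)
    sqrt_s1 = (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T
    tr_covmean = np.sqrt(np.clip(np.linalg.eigvalsh(sqrt_s1 @ sigma2 @ sqrt_s1), 0.0, None)).sum()
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)

def fid_from_npz(gen_path, ref_path):
    """FID between two precomputed statistics files."""
    return frechet_distance(*load_stats(gen_path), *load_stats(ref_path))

# Usage with synthetic 2-d stats (hypothetical data, not real Inception statistics)
with tempfile.TemporaryDirectory() as tmp:
    gen_path = os.path.join(tmp, "gen.npz")
    ref_path = os.path.join(tmp, "ref.npz")
    np.savez(gen_path, mu=np.zeros(2), sigma=np.eye(2))
    np.savez(ref_path, mu=np.array([3.0, 4.0]), sigma=np.eye(2))
    demo_fid = fid_from_npz(gen_path, ref_path)
print(demo_fid)  # identical covariances, so FID = 3**2 + 4**2 = 25
```

Reusing one fixed reference file is what keeps n_real constant across runs, so only n_fake has to be matched.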

This gave me an FID value of ~4.6, which seems to be on the right order of magnitude:

Calculating statistics for 50000 images...
100%|██████████| 782/782 [01:05<00:00, 12.01batch/s]
Calculating FID...
4.5546

I expect that more training (with stronger dropout/augmentation to avoid overfitting) is needed to match the ~2.0 FID values reported in the EDM paper.

shonenkov commented 1 year ago

https://github.com/crowsonkb/k-diffusion/pull/78

WayneDW commented 9 months ago

I am having similar issues. The FID score is pretty high.