LuChengTHU / dpm-solver

Official code for "DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps" (NeurIPS 2022 Oral)
MIT License

Results of Imagenet256 using DPM++ not reproducible #47

Closed: wonkyoc closed this issue 10 months ago

wonkyoc commented 10 months ago

Problem: In the DPM++ paper, the FID reported for the Imagenet256 benchmark with [s=8.0 | NFE=25] is 8.39; however, running this repo does not produce a comparable number (the result is FID ≈ 200).

Env: 4 × NVIDIA RTX 2080 Ti GPUs, PyTorch==1.13.1, CUDA 11.7

config/imagenet256_guided.yml

# the rest of the options are unchanged
sampling:
+    batch_size: 10
+    fid_total_samples: 10000

What I did: I simply ran ddpm_and_guided-diffusion/sample.sh with the CIFAR10/Imagenet64 runs commented out, so that only Imagenet256 is executed.

Question: Are there any parameters that should be changed? I checked everything mentioned in the paper and in this repo's README.

LuChengTHU commented 10 months ago

Hi @wonkyoc ,

Thanks for following our work!

I've checked the script and config and I can successfully sample the images. Could you please show one of your sampled images here?

wonkyoc commented 10 months ago

Hi @LuChengTHU,

Thanks for the fast response; I am glad the code itself works. Unfortunately, I accidentally deleted the results for fid_total_samples: 10000, but I assume fid_total_samples: 1000 should be enough to track down the problem. The information below is from a run I have just redone. (While waiting for your reply, I will start a run with fid_total_samples: 10000, but it will take a few hours on my GPUs ;)

# sample.sh
data="imagenet256_guided"
scale="8.0"
sampleMethod='dpmsolver++'  # or 'dpmsolver'
type="dpmsolver"
steps="20"
DIS="time_uniform"
order="2"
method="multistep"
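As a rough illustration of what steps="20" and DIS="time_uniform" control, here is a sketch of a time-uniform schedule. It assumes, as I understand the DPM-Solver reference code, that time_uniform spaces steps + 1 time points linearly from t_T = 1 down to t_0 = 1 / total_N for a discrete-time model (total_N: 1000 in the config); time_uniform_steps is a hypothetical helper, not a function from this repo:

```python
def time_uniform_steps(steps: int, t_T: float = 1.0, total_N: int = 1000) -> list:
    """Hypothetical helper: a 'time_uniform' schedule of steps + 1 time points."""
    # Assumed end point for a discrete-time DPM: t_0 = 1 / total_N.
    t_0 = 1.0 / total_N
    # Linearly spaced from t_T down to t_0, like a linspace with steps + 1 points.
    return [t_T + (t_0 - t_T) * i / steps for i in range(steps + 1)]


if __name__ == "__main__":
    # Values taken from the sample.sh / config above: steps="20", total_N: 1000.
    ts = time_uniform_steps(steps=20, total_N=1000)
    print(len(ts), ts[0], ts[-1])
```

Each consecutive pair of time points defines one solver step, so equal spacing in t is the simplest choice; other skip types in DPM-Solver space the points differently.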

# config/imagenet256_guided.yml
sampling:
    total_N: 1000
    batch_size: 10
    last_only: True
    fid_stats_dir: "./fid_stats/VIRTUAL_imagenet256_labeled.npz"
    fid_total_samples: 100
    fid_batch_size: 100

[Attached sample images: DPM++2M and DPM2M]

Hope this can help!

LuChengTHU commented 10 months ago

@wonkyoc ,

Even with samples like these, the FID cannot be that poor (~200). I suspect the FID stats file you used may be incorrect. Could you please check it?
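For context, FID compares the mean and covariance of Inception features of the generated samples ($\mu_g$, $\Sigma_g$) with the reference statistics stored in the stats file ($\mu_r$, $\Sigma_r$, here from VIRTUAL_imagenet256_labeled.npz) via the Fréchet distance between two Gaussians. If the reference statistics come from the wrong dataset or resolution, both terms stay large no matter how good the samples are:

$$
\mathrm{FID} = \lVert \mu_g - \mu_r \rVert_2^2 + \operatorname{Tr}\!\left(\Sigma_g + \Sigma_r - 2\,(\Sigma_g \Sigma_r)^{1/2}\right)
$$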

LuChengTHU commented 10 months ago

Hi @wonkyoc ,

I've checked the README carefully, and I am very sorry about this typo (see commit 52bc3fbcd5de56d60917b826b15d2b69460fc2fa).

Is your bug caused by this incorrect FID stats file?

wonkyoc commented 10 months ago

@LuChengTHU

I had actually noticed that typo and used the right file for Imagenet256. I also agree that images like that one (and the others) cannot produce such a high FID. I suspect one of the packages used to calculate FID may be the culprit...

pytorch_fid==0.3.0
scipy==1.9.1

I will investigate further.

LuChengTHU commented 10 months ago

@wonkyoc

Please try PyTorch 1.x instead of 0.x (I used torch==1.12.1).

wonkyoc commented 10 months ago

Oh, that's a typo. I am using pytorch==1.13.1; what I meant to write was pytorch-fid==0.3.0. Anyway, luckily, I found the issue on my end. The problem was that I had put torch.manual_seed(0) in the SDE evaluation, so the sampling randomness was reset and only a fixed set of images was generated over and over, which means the FID was not computed on a proper sample (i.e., the same 50 images were generated repeatedly to fill the 10k samples). I am closing the issue. Thanks for your help!
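A minimal sketch of this failure mode, assuming the seed was being reset once per sampling batch (the thread does not show exactly where torch.manual_seed(0) was placed). It uses Python's standard random module as a stand-in for PyTorch's RNG, and sample_batch, generate_buggy, generate_fixed and count_distinct are hypothetical helpers, not part of this repo. Counting distinct samples before computing FID is a cheap way to catch this kind of bug:

```python
import random

# Hypothetical stand-in for one sampling batch; each "image" is represented
# by the tuple of initial-noise values it would be generated from.
IMAGES_PER_BATCH = 10


def sample_batch(rng: random.Random) -> list:
    """Hypothetical sampler: draws one batch of initial-noise vectors."""
    return [tuple(rng.random() for _ in range(4)) for _ in range(IMAGES_PER_BATCH)]


def generate_buggy(num_batches: int) -> list:
    samples = []
    for _ in range(num_batches):
        # BUG (assumed placement): re-seeding inside the loop resets the RNG,
        # so every batch starts from the same noise and repeats the same images.
        rng = random.Random(0)
        samples.extend(sample_batch(rng))
    return samples


def generate_fixed(num_batches: int, seed: int = 0) -> list:
    # Seed once, before the loop, so later batches draw fresh noise.
    rng = random.Random(seed)
    samples = []
    for _ in range(num_batches):
        samples.extend(sample_batch(rng))
    return samples


def count_distinct(samples: list) -> int:
    """Sanity check to run before FID: how many samples are actually unique."""
    return len(set(samples))


if __name__ == "__main__":
    buggy = generate_buggy(100)
    fixed = generate_fixed(100)
    print("buggy:", len(buggy), "samples,", count_distinct(buggy), "distinct")
    print("fixed:", len(fixed), "samples,", count_distinct(fixed), "distinct")
```

In PyTorch, the analogous fix is to call torch.manual_seed once before the sampling loop (or create a single torch.Generator and pass it to the noise calls) rather than inside it, so reproducibility is kept without collapsing sample diversity.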

FYI, the FID I now get with 10K samples is 8.06.