AlexiaJM / score_sde_fast_sampling

Repository for the "Gotta Go Fast When Generating Data with Score-Based Models" paper

Running your code on FFHQ 1024 #4

Open yaseryacoob opened 2 years ago

yaseryacoob commented 2 years ago

Thanks for sharing your code. I am trying to compare the performance of your code to Song's on the FFHQ 1024x1024 data. You seem to have a config file, ffhq.py, that might do that, but I am not sure how to test it. I also see that ffhq-r10.tfrecords is needed, and I am not sure which checkpoint and options to use for the call.

Can you please provide suggestions? thanks

AlexiaJM commented 2 years ago

Hi Yaser,

You can use the lines of code here https://github.com/AlexiaJM/score_sde_fast_sampling/blob/5da8f3fe103ee5ac3c3a336f16cc06c9541f0ed9/experiments.sh#L165 to compare the different sampling methods on FFHQ-256. You'll want to replace ffhq_256.py with ffhq.py in those lines.

Based on the paper, for optimal results, you'll want to use this line: https://github.com/AlexiaJM/score_sde_fast_sampling/blob/5da8f3fe103ee5ac3c3a336f16cc06c9541f0ed9/experiments.sh#L170

You can download FFHQ from https://github.com/NVlabs/ffhq-dataset.
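The config swap mentioned above (ffhq_256.py to ffhq.py) can be scripted instead of edited by hand. The following is only a minimal sketch: `swap_config` is a hypothetical helper, and the command string is a made-up placeholder, not the actual line from experiments.sh.

```python
import re


def swap_config(command: str, old: str = "ffhq_256.py", new: str = "ffhq.py") -> str:
    """Replace the config file name `old` with `new` in a shell command string.

    Only whole file names are replaced: a match is rejected when `old` is
    embedded in a longer name (for example `my_ffhq_256.py`).
    """
    pattern = r"(?<![\w.-])" + re.escape(old) + r"(?![\w.-])"
    # A function replacement avoids any backslash interpretation in `new`.
    return re.sub(pattern, lambda _match: new, command)


# Placeholder command for illustration; the real lines live in experiments.sh.
example = "python main.py --config path/to/ffhq_256.py --mode eval"
print(swap_config(example))
```

Matching on the whole file name keeps the substitution from touching unrelated arguments that merely contain the same characters.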

yaseryacoob commented 2 years ago

Alexia, thanks for the quick response. Let me clarify:

  1. I was planning to just run the evaluation part to compare performance with respect to Song's PyTorch checkpoint_60.pth (for FFHQ 1024x1024). So I am not sure I understand why the FFHQ data is needed for this.
  2. If I understand your comments correctly, I will have to train on FFHQ before I get to step 1? I didn't see any reference to a pretrained checkpoint (I assumed checkpoint_60 might be used with your SDE solver).
  3. I have the FFHQ 1024x1024 dataset but not the multi-resolution tfrecords. I would expect training on the full FFHQ to take an enormous amount of time?

Thanks for any further clarification.

AlexiaJM commented 2 years ago

  1. You need the real data to calculate the FID. Assuming you already have the FID statistics and don't have eval.enable_loss=False, you could just comment out the line https://github.com/AlexiaJM/score_sde_fast_sampling/blob/5da8f3fe103ee5ac3c3a336f16cc06c9541f0ed9/run_lib.py#L234 and things should work without the dataset. The problem is that I have never computed those FID statistics, so you will need to compute them with the correct dataset (tfrecords).
  2. You do not need to train; the data is only used for the FID calculation. You can use the checkpoint at https://drive.google.com/drive/folders/1GwcthBS4Ry54eA_fIg1hOCfThQ6I3u1L and change workdir to 'ffhq_1024_ncsnpp_continuous', which is the folder containing the checkpoint.
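
To illustrate why the real data matters in point 1: FID compares the mean and covariance of Inception features from generated samples against the same statistics computed on real images, so without reference statistics there is nothing to compare against. Below is a minimal sketch of the distance under the simplifying assumption of diagonal covariances. The actual FID uses full covariance matrices and a matrix square root, and `frechet_distance_diag` is a hypothetical helper, not a function from this repository.

```python
import math
from typing import Sequence


def frechet_distance_diag(
    mu1: Sequence[float],
    var1: Sequence[float],
    mu2: Sequence[float],
    var2: Sequence[float],
) -> float:
    """Frechet distance between N(mu1, diag(var1)) and N(mu2, diag(var2)).

    With diagonal covariances the general formula
        ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2))
    reduces to per-dimension sums.
    """
    if not (len(mu1) == len(var1) == len(mu2) == len(var2)):
        raise ValueError("all inputs must have the same length")
    mean_term = sum((a - b) ** 2 for a, b in zip(mu1, mu2))
    cov_term = sum(a + b - 2.0 * math.sqrt(a * b) for a, b in zip(var1, var2))
    return mean_term + cov_term


# Toy statistics standing in for "real" and "generated" feature statistics.
real_mu, real_var = [0.0, 0.0], [1.0, 1.0]
fake_mu, fake_var = [1.0, 0.0], [1.0, 4.0]
print(frechet_distance_diag(real_mu, real_var, fake_mu, fake_var))
```

The real-data side (`real_mu`, `real_var` here) is what has to be precomputed from the dataset before any generated samples can be scored.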

yaseryacoob commented 2 years ago

Thanks Alexia, I tried this route, but the ml-collections and absl seem to have some issues beyond my knowledge. I will let go of it for now and revisit it in the future.