openai / consistency_models_cifar10

Consistency models trained on CIFAR-10, in JAX.
Apache License 2.0

Configuration to reproduce the CIFAR-10 CT adaptive training result in the paper. #4

Open Zyriix opened 1 year ago

Zyriix commented 1 year ago

Thanks for your innovative work!

I'm new to consistency models and am trying to reproduce the FID curve in your paper (Fig. 3d) with PyTorch. However, when using the configuration in this repo, I found that the FID does not converge as in Fig. 3d. I think the cause might be one of the following configuration choices:

  1. warmup: following the paper, I use no warmup. I also tried a 10M warmup like this repo, but then the FID converges much more slowly than in Fig. 3d.
  2. learning rate: I use 4e-4.
  3. EMA decay of the weights used to generate samples: following the paper, I use 0.9999.
  4. FIR resampling: I did not implement the FIR kernel in my implementation.
  5. Fourier position embedding: I use a sin/cos position embedding in my implementation (see the sketch after this list).
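
For reference, here is a minimal PyTorch sketch of the random Fourier embedding used by NCSN++-style networks, as opposed to the sin/cos embedding in item 5. The `embedding_dim` and `scale=16` values are assumptions based on the score-SDE defaults, not values confirmed against this repo:

```python
import math
import torch
import torch.nn as nn

class GaussianFourierEmbedding(nn.Module):
    """Random Fourier feature embedding of the noise level (NCSN++ style).

    The frequency vector is sampled once at init and kept fixed (not
    trained), unlike a deterministic sin/cos positional embedding, which
    uses a fixed geometric ladder of frequencies.
    """

    def __init__(self, embedding_dim: int = 256, scale: float = 16.0):
        super().__init__()
        # Fixed random frequencies; scale=16 follows the score-SDE configs.
        self.register_buffer("W", torch.randn(embedding_dim // 2) * scale)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # t: (batch,) noise levels -> (batch, embedding_dim)
        proj = t[:, None] * self.W[None, :] * 2.0 * math.pi
        return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)
```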

I'm quite sure I implemented the other components correctly (the adaptive EMA decay μ(k), the adaptive step count N(k), and the skip scaling c).
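
For concreteness, here is my reading of the adaptive schedules from the CT paper (Song et al., 2023): N(k) ramps up the number of discretization steps over training, and μ(k) sets the EMA decay of the target network. A minimal sketch, assuming the paper's CIFAR-10 settings s0=2, s1=150, μ0=0.9:

```python
import math

def adaptive_schedules(k: int, K: int, s0: int = 2, s1: int = 150, mu0: float = 0.9):
    """Adaptive step count N(k) and target-EMA decay mu(k) from the CT paper.

    k is the current training step and K the total number of steps; s0, s1,
    and mu0 are the paper's CIFAR-10 settings (assumed here, not confirmed
    against this repo's config).
    """
    # N(k): grow the discretization from s0 + 1 up to s1 + 1 steps.
    n_k = math.ceil(math.sqrt(k / K * ((s1 + 1) ** 2 - s0 ** 2) + s0 ** 2) - 1) + 1
    # mu(k): increase the EMA decay as the discretization gets finer.
    mu_k = math.exp(s0 * math.log(mu0) / n_k)
    return n_k, mu_k
```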

I reached an FID of 12.84 on CIFAR-10 using clean-fid's pytorch-legacy statistics.
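
For reference, the number above was computed roughly like this with clean-fid; the samples folder is a placeholder for a directory of generated PNGs:

```python
from cleanfid import fid

# FID against clean-fid's precomputed CIFAR-10 statistics, in the
# "legacy_pytorch" mode that matches pytorch-fid's Inception weights.
score = fid.compute_fid(
    "samples/",                  # placeholder: folder of generated images
    dataset_name="cifar10",
    dataset_res=32,
    dataset_split="train",
    mode="legacy_pytorch",
)
print(f"FID: {score:.2f}")
```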

Do you think FIR resampling and the Fourier position embedding are important for reproducing the result in the paper? What is the configuration for reproducing the FID curve in Fig. 3d?

Thanks again for your generous contributions!

Zyriix commented 1 year ago

I use the U-Net model from your PyTorch repo, with the configuration following this repo plus `use_scale_shift_norm=True`, `resblock_updown=True`, and `num_head_channels=64`.
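
For completeness, this is roughly how I instantiate it. Everything except the three flags above is my guess at reasonable CIFAR-10 values, assuming the `UNetModel` signature from the PyTorch repo (which follows guided-diffusion):

```python
from cm.unet import UNetModel  # openai/consistency_models (PyTorch)

# All values other than the three flags discussed above are assumptions
# about sensible CIFAR-10 settings, not taken from this repo's config.
model = UNetModel(
    image_size=32,
    in_channels=3,
    model_channels=128,
    out_channels=3,
    num_res_blocks=2,
    attention_resolutions=(2,),   # downsample factor 2, i.e. the 16x16 maps
    dropout=0.0,
    channel_mult=(1, 2, 2, 2),
    use_scale_shift_norm=True,
    resblock_updown=True,
    num_head_channels=64,
)
```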