I'm new to consistency models and am trying to reproduce the FID curve in your paper (Fig. 3d) with PyTorch. However, when using the configuration in this repo, I found the FID does not converge as in Fig. 3d. I suspect the cause lies in one of the following configuration choices:
- warmup: following the paper, I use no warmup. I also tried a 10M warmup as in this repo, but then the FID converges much more slowly than in Fig. 3d.
- learning rate: I use 4e-4.
- EMA weight for generating samples: following the paper, I use 0.9999.
- FIR resampling: I did not implement the FIR kernel in my implementation.
- Fourier positional embedding: I use a sin/cos positional embedding instead.
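For reference, here is a minimal sketch of the two timestep-embedding variants I am comparing. The Gaussian Fourier version follows the NCSN++-style random-Fourier-feature construction; the `scale=16.0` and the embedding size are my assumptions, not values taken from the paper:

```python
import math
import torch

def sinusoidal_embedding(t, dim):
    # Standard Transformer-style sin/cos timestep embedding (what I use now).
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

class GaussianFourierEmbedding(torch.nn.Module):
    # Random-Fourier-feature embedding (NCSN++-style): frequencies are sampled
    # once at init and kept fixed (registered as a buffer, not trained).
    def __init__(self, dim, scale=16.0):  # scale=16.0 is an assumption
        super().__init__()
        self.register_buffer("W", torch.randn(dim // 2) * scale)

    def forward(self, t):
        args = 2 * math.pi * t[:, None].float() * self.W[None, :]
        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

t = torch.tensor([0.0, 1.0, 80.0])
print(sinusoidal_embedding(t, 128).shape)      # torch.Size([3, 128])
print(GaussianFourierEmbedding(128)(t).shape)  # torch.Size([3, 128])
```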
I'm fairly confident that my implementation of the other components (such as the adaptive u, adaptive T, and skip factor c) is correct.
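To make it easier to spot a discrepancy, here is a sketch of those components as I understand them from the paper: the adaptive step-count schedule N(k), the adaptive EMA rate mu(k), and the boundary-preserving skip scalings. The constants (s0=2, s1=150, mu0=0.9, sigma_data=0.5, eps=0.002) are the CIFAR-10 values I believe the paper reports; please correct me if any differ from the official setup:

```python
import math

def N_schedule(k, K, s0=2, s1=150):
    # Number of discretization steps at training iteration k out of K total;
    # ramps from s0 + ... up to s1 + 1 over training.
    return math.ceil(math.sqrt(k / K * ((s1 + 1) ** 2 - s0 ** 2) + s0 ** 2) - 1) + 1

def mu_schedule(k, K, mu0=0.9, s0=2, s1=150):
    # Adaptive EMA decay for the target network, tied to the current N(k).
    return math.exp(s0 * math.log(mu0) / N_schedule(k, K, s0, s1))

def c_skip(t, sigma_data=0.5, eps=0.002):
    # Skip-connection weight; equals 1 at the boundary t = eps.
    return sigma_data ** 2 / ((t - eps) ** 2 + sigma_data ** 2)

def c_out(t, sigma_data=0.5, eps=0.002):
    # Output scaling; vanishes at the boundary t = eps.
    return sigma_data * (t - eps) / math.sqrt(sigma_data ** 2 + t ** 2)

# Boundary check: at t = eps the model must reduce to the identity map.
print(c_skip(0.002), c_out(0.002))                            # 1.0 0.0
print(N_schedule(0, 400_000), N_schedule(400_000, 400_000))   # 2 151
```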
I reached an FID of 12.84 on CIFAR-10 using clean-fid's pytorch-legacy statistics.
Do you think FIR resampling and the Fourier positional embedding are important for reproducing the results in the paper?
What is the configuration for reproducing the FID curve in Fig. 3d?
I use the U-Net model in your PyTorch repo and follow this repo's configuration, with:
- use_scale_shift_norm=True
- resblock_updown=True
- num_head_channels=64
Thanks for your innovative work!