NVlabs / edm

Elucidating the Design Space of Diffusion-Based Generative Models (EDM)

Question about Fourier Embedding and EDM Loss calculation #27

Closed zhengyu-su closed 2 months ago

zhengyu-su commented 2 months ago

Hi, I am trying to use the pre-trained cifar10-32x32-cond-ve model for a fine-tuning task. During implementation I ran into an error in the loss calculation: https://github.com/NVlabs/edm/blob/008a4e5316c8e3bfe61a62f874bddba254295afb/training/loss.py#L66-L80

To embed the noise level sigma, the following calculation is used: https://github.com/NVlabs/edm/blob/008a4e5316c8e3bfe61a62f874bddba254295afb/training/networks.py#L217-L220 torch.ger() requires its input to be a 1-D vector, but sigma as sampled in the EDM loss has shape [image.shape[0],1,1,1], which is a 4-D tensor and cannot be passed through it. Is this a small bug in the code, or am I doing something wrong somewhere to get this error?
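The shape mismatch described above can be reproduced in isolation. This is a minimal sketch assuming PyTorch; `fourier_embed`, the number of frequencies and the scale are illustrative choices, not the repository's code. It only shows that an outer product rejects a `[B,1,1,1]` tensor and accepts the same values once flattened to `[B]`.

```python
import math
import torch


def fourier_embed(noise_cond, freqs):
    """Outer-product Fourier features (hypothetical helper).

    noise_cond must be 1-D with shape [B]; freqs is 1-D with shape [C/2].
    """
    # torch.ger is an alias of torch.outer; both require 1-D inputs.
    x = torch.outer(noise_cond, 2 * math.pi * freqs)
    return torch.cat([x.cos(), x.sin()], dim=1)


batch = 4
freqs = torch.randn(8) * 16               # assumed: 8 random frequencies, scale 16
sigma = torch.rand(batch, 1, 1, 1) + 0.1  # 4-D so it broadcasts against images

# Passing the 4-D tensor directly fails, which is the error from the question.
try:
    fourier_embed(sigma, freqs)
    raised = False
except RuntimeError:
    raised = True

# Flattening to shape [B] first makes the same call valid.
emb = fourier_embed(sigma.flatten(), freqs)
print(raised, tuple(emb.shape))
```

The 4-D shape is convenient in the loss because it broadcasts against image tensors, so a flatten (or reshape) has to happen somewhere before the embedding layer is reached.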

Another question is about the scaling of the loss. Since the weight depends on the randomly sampled sigma and sometimes gets very large, how do you choose the scaling factor for it in the training loop? Even when the prediction and the input are similar, the loss can be huge.
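One way to look at the large weights is to compare them with the output scaling used by the preconditioning. The sketch below is an illustration, not the repository's code: it assumes the EDM formulation from the paper, with weight (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2, output scale c_out = sigma * sigma_data / sqrt(sigma^2 + sigma_data^2), and the values sigma_data = 0.5, P_mean = -1.2, P_std = 1.2, which are assumed here. Under these assumptions the weight equals 1 / c_out^2, so it cancels the factor by which the denoiser error is scaled at each noise level.

```python
import math
import random

SIGMA_DATA = 0.5            # assumed value
P_MEAN, P_STD = -1.2, 1.2   # assumed log-normal parameters for sigma


def sample_sigma(rng):
    """Draw sigma from a log-normal distribution: ln(sigma) ~ N(P_MEAN, P_STD^2)."""
    return math.exp(rng.gauss(P_MEAN, P_STD))


def loss_weight(sigma):
    """Per-sample weight applied to the squared denoiser error."""
    return (sigma**2 + SIGMA_DATA**2) / (sigma * SIGMA_DATA) ** 2


def c_out(sigma):
    """Scale applied to the raw network output inside the denoiser."""
    return sigma * SIGMA_DATA / math.sqrt(sigma**2 + SIGMA_DATA**2)


rng = random.Random(0)
sigmas = sorted(sample_sigma(rng) for _ in range(5))
products = [loss_weight(s) * c_out(s) ** 2 for s in sigmas]

for s, p in zip(sigmas, products):
    print(f"sigma={s:.4f}  weight={loss_weight(s):.2f}  weight*c_out^2={p:.6f}")
```

If the network output is not scaled by c_out (for example when the raw network is called without its preconditioning wrapper), this cancellation no longer holds and the weighted loss can blow up at small sigma.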

Could someone please take a look and let me know what might be going wrong? Thank you in advance for your help!

zhengyu-su commented 2 months ago

I just found out that if the network is built with one of the preconditioning wrappers, the noise input gets flattened during the wrapper's forward pass.
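The behaviour described here can be sketched as follows. This is a simplified illustration assuming PyTorch and the VE preconditioning from the paper (c_skip = 1, c_out = sigma, c_in = 1, c_noise = ln(sigma / 2)); `TinyNet` and `VEPrecondSketch` are hypothetical names, and the real wrapper in the repository has more arguments (class labels, mixed precision, and so on).

```python
import torch


class TinyNet(torch.nn.Module):
    """Hypothetical stand-in for the inner network; expects 1-D noise labels [B]."""

    def forward(self, x, noise_labels):
        assert noise_labels.ndim == 1 and noise_labels.shape[0] == x.shape[0]
        return x * noise_labels.reshape(-1, 1, 1, 1).tanh()


class VEPrecondSketch(torch.nn.Module):
    """Simplified VE-style preconditioning wrapper (illustrative only)."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, sigma):
        # Keep sigma 4-D so it broadcasts against the image batch.
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        c_skip, c_out, c_in = 1.0, sigma, 1.0
        c_noise = (0.5 * sigma).log()
        # The inner network receives the noise conditioning flattened to [B].
        F_x = self.model(c_in * x, c_noise.flatten())
        return c_skip * x + c_out * F_x


net = VEPrecondSketch(TinyNet())
images = torch.randn(4, 3, 32, 32)
sigma = torch.rand(4, 1, 1, 1) + 0.1   # same shape the loss produces
denoised = net(images + sigma * torch.randn_like(images), sigma)
print(tuple(denoised.shape))
```

With this arrangement the loss can keep sigma in shape `[B,1,1,1]`, and only the wrapper needs to know that the embedding layer wants a 1-D vector.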