tianweiy / DMD2

(NeurIPS 2024 Oral 🔥) Improved Distribution Matching Distillation for Fast Image Synthesis

Question about gen_cls_loss and guidance_cls_loss #26

Closed · weleen closed this 2 months ago

weleen commented 2 months ago

Hi Tianwei,

Thank you for your wonderful work.

I have some questions about `gen_cls_loss` and `guidance_cls_loss` in the code.

`gen_cls_loss` and `gen_cls_loss_weight`:
https://github.com/tianweiy/DMD2/blob/0f8a481716539af7b2795740c9763a7d0d05b83b/main/sd_guidance.py#L323
https://github.com/tianweiy/DMD2/blob/0f8a481716539af7b2795740c9763a7d0d05b83b/main/train_sd.py#L372

`guidance_cls_loss`:
https://github.com/tianweiy/DMD2/blob/0f8a481716539af7b2795740c9763a7d0d05b83b/main/sd_guidance.py#L391

Following your paper, I cannot find the definition of `gen_cls_loss`, and I noticed that `gen_cls_loss_weight` is set to different values for SDv1.5 and SDXL (1e-3 vs. 5e-3). Could you elaborate on how to tune this hyperparameter?

Also, `guidance_cls_loss` is formulated as follows:

[image: the L_GAN equation from the paper]

It seems that `log` is replaced by `softplus` in the code.

tianweiy commented 2 months ago

I cannot find the definition of gen_cls_loss,

It is the standard non-saturating GAN loss; it is basically the second half of the L_GAN equation you sent above. I don't have many ideas about tuning the hyperparameter. For SDv1.5, training seems less stable with a 5e-3 weight. I would suggest starting with something like 1e-2 and gradually reducing it to the value that produces the best metric. 1e-2 may sound small, but it is actually required: `loss_dmd` looks large, but its gradient norm is really small once it is backpropagated, because the MSE formulation implicitly weights down the gradient.
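For reference, the non-saturating generator loss described here can be sketched as follows. This is a minimal sketch, not the repo's exact code: `fake_logits` is assumed to be the discriminator's raw output on generated images, and the default weight mirrors the 1e-2 starting point suggested above.

```python
import numpy as np

def softplus(x):
    # Numerically stable log(1 + exp(x))
    return np.logaddexp(0.0, x)

def gen_cls_loss(fake_logits, weight=1e-2):
    # Non-saturating GAN generator loss: -log sigmoid(D(G(z))),
    # rewritten as softplus(-logits) for numerical stability.
    # The generator is rewarded when the discriminator assigns
    # high (real-looking) logits to its samples.
    return weight * softplus(-np.asarray(fake_logits)).mean()
```

With raw logits, confident "real" predictions (large positive values) give a small loss, while confident "fake" predictions (large negative values) give a large, non-saturating gradient signal.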

https://github.com/tianweiy/DMD2/blob/0f8a481716539af7b2795740c9763a7d0d05b83b/main/sd_guidance.py#L241

It seems that log is replaced by softplus in the code.

I think the math is equivalent. Note that the raw output of the network is not the predicted ratio: we omitted a sigmoid here, and if you combine the sigmoid with the softplus, it gives exactly the log. This loss definition follows StyleGAN2: https://github.com/NVlabs/stylegan2-ada-pytorch/blob/d72cc7d041b42ec8e806021a205ed9349f87c6a4/training/loss.py#L71
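Concretely, for logits x the identities are log σ(x) = −softplus(−x) and log(1 − σ(x)) = −softplus(x), so applying softplus to raw logits is the same as applying a sigmoid first and then taking the log. A small numerical check, independent of the repo's code:

```python
import numpy as np

def softplus(x):
    # Numerically stable log(1 + exp(x))
    return np.logaddexp(0.0, x)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

x = np.linspace(-5.0, 5.0, 11)  # raw discriminator logits

# log D(x) two ways: sigmoid-then-log vs. softplus on raw logits
log_d_direct = np.log(sigmoid(x))
log_d_softplus = -softplus(-x)

# log(1 - D(x)) likewise
log_1md_direct = np.log(1.0 - sigmoid(x))
log_1md_softplus = -softplus(x)
```

The softplus form is preferred in practice because it avoids computing `log(sigmoid(x))` directly, which underflows for large negative logits.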

weleen commented 2 months ago

Thanks!