Description
Thanks a lot for your brilliant work.
Correct me if I am wrong, but I think the grad scale of the SDS loss should be multiplied into the loss rather than into the grad.
With the SDS loss formulated as
loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]
the grad scale should be applied to the loss with
loss = loss * grad_scale
instead of being multiplied into the grad before the targets are computed (here).
Steps to Reproduce
NA
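Below is a minimal sketch that compares the two placements side by side. It assumes PyTorch, uses random tensors in place of real UNet outputs, and assumes an SDS grad of the form w * (noise_pred - noise). The function names, the tensor shape, w and the default grad_scale are hypothetical illustration choices, not taken from the repository. The sketch builds both losses from the same inputs and returns each loss value together with the gradient it sends to latents.

```python
import torch
import torch.nn.functional as F


def sds_loss_scale_grad(latents, noise_pred, noise, w, grad_scale):
    # Placement 1: grad_scale is folded into the grad before the target is built.
    grad = grad_scale * w * (noise_pred - noise)
    targets = (latents - grad).detach()  # detached, so autograd treats it as a constant
    return 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]


def sds_loss_scale_loss(latents, noise_pred, noise, w, grad_scale):
    # Placement 2: the target uses the unscaled grad and the loss is scaled afterwards.
    grad = w * (noise_pred - noise)
    targets = (latents - grad).detach()
    loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]
    return loss * grad_scale


def compare(grad_scale=3.0, seed=0):
    # Random stand-ins for the UNet prediction, the sampled noise and the latents.
    torch.manual_seed(seed)
    shape = (2, 4, 8, 8)  # hypothetical (batch, channels, height, width)
    noise_pred = torch.randn(shape)
    noise = torch.randn(shape)
    base_latents = torch.randn(shape)
    w = 0.5  # hypothetical timestep weight w(t)

    results = {}
    for name, loss_fn in (("scale_grad", sds_loss_scale_grad),
                          ("scale_loss", sds_loss_scale_loss)):
        latents = base_latents.clone().requires_grad_(True)
        loss = loss_fn(latents, noise_pred, noise, w, grad_scale)
        loss.backward()
        # Keep the scalar loss and the gradient that reaches the latents.
        results[name] = (loss.item(), latents.grad.clone())
    return results


if __name__ == "__main__":
    out = compare(grad_scale=3.0)
    loss_1, grad_1 = out["scale_grad"]
    loss_2, grad_2 = out["scale_loss"]
    print("max |grad difference|:", (grad_1 - grad_2).abs().max().item())
    print("loss ratio:", loss_1 / loss_2)
```

Because targets is detached, the gradient of 0.5 * sum((latents - targets)**2) / B with respect to latents is (latents - targets) / B. In placement 1 that works out to grad_scale * w * (noise_pred - noise) / B. In placement 2 it is grad_scale times w * (noise_pred - noise) / B. The two gradients should therefore agree up to floating-point rounding. The reported loss value is what differs: in placement 1 it grows with grad_scale squared, and in placement 2 it grows linearly. The printed ratio should therefore come out close to grad_scale.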
Expected Behavior
NA
Environment
NA