Closed: KaiWU5 closed this issue 8 months ago
In the code of `AddAuxiliaryLoss`, the loss is not stored or used in the forward function. Does that mean the gradient is always 1? Should it be `grad_output * loss`? Thanks a lot if you can straighten this out for me.
```python
class AddAuxiliaryLoss(torch.autograd.Function):
    """
    The trick function of adding auxiliary (aux) loss,
    which includes the gradient of the aux loss during backpropagation.
    """

    @staticmethod
    def forward(ctx, x, loss):
        assert loss.numel() == 1
        ctx.dtype = loss.dtype
        ctx.required_aux_loss = loss.requires_grad
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad_loss = None
        if ctx.required_aux_loss:
            grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
        return grad_output, grad_loss
```
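For context, a hedged sketch of how a function like this is typically wired into an MoE forward pass; the names `hidden_states` and `aux_loss` here are illustrative, not quoted from the modeling code:

```python
# Illustrative usage: attach the aux loss to the expert output so its gradient
# reaches the router during backprop, while the returned value is unchanged.
hidden_states = AddAuxiliaryLoss.apply(hidden_states, aux_loss)
```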
@DeepSeekDDM @zwd003 Thanks a lot for helping.
For the aux loss, the gradient with respect to itself is always 1: if the total loss were written as `main_loss + aux_loss`, then d(total)/d(aux) = 1. So returning a constant gradient of 1 for the loss input in `backward` is equivalent to adding the aux loss to the final loss, without keeping it in the forward output.
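A minimal sketch to check that equivalence numerically, assuming the `AddAuxiliaryLoss` class above is in scope; the toy tensors `w`, `gate`, and the helper `grads` are hypothetical stand-ins, not part of the original code:

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 4)
w = torch.randn(4, 4, requires_grad=True)      # stand-in for expert weights
gate = torch.randn(4, 4, requires_grad=True)   # stand-in for the router/gate

def grads(use_trick: bool):
    w.grad = gate.grad = None
    h = x @ w
    aux_loss = (x @ gate).pow(2).mean()         # stand-in for the load-balancing loss
    if use_trick:
        h = AddAuxiliaryLoss.apply(h, aux_loss)
        loss = h.sum()                          # main loss only; aux grad injected in backward
    else:
        loss = h.sum() + aux_loss               # aux loss added to the final loss explicitly
    loss.backward()
    return w.grad.clone(), gate.grad.clone()

gw1, gg1 = grads(True)
gw2, gg2 = grads(False)
print(torch.allclose(gw1, gw2), torch.allclose(gg1, gg2))  # expected: True True
```

Both paths give the same parameter gradients: the trick simply injects the constant gradient of 1 for the aux loss during backprop instead of adding the scalar to the forward loss value.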