deepseek-ai / DeepSeek-MoE

DeepSeekMoE: Towards Ultimate Expert Specialization in Mixture-of-Experts Language Models
MIT License

Question about AddAuxiliaryLoss? #17

Closed by KaiWU5 8 months ago

KaiWU5 commented 8 months ago

In the AddAuxiliaryLoss code, the loss is not stored or used in the forward function. Does that mean its gradient is constantly 1? Should it be grad_output * loss instead?

Thanks a lot if you can straighten this out for me.

class AddAuxiliaryLoss(torch.autograd.Function):
    """
    The trick function of adding auxiliary (aux) loss,
    which includes the gradient of the aux loss during backpropagation.
    """
    @staticmethod
    def forward(ctx, x, loss):
        assert loss.numel() == 1
        ctx.dtype = loss.dtype
        ctx.required_aux_loss = loss.requires_grad
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad_loss = None
        if ctx.required_aux_loss:
            grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
        return grad_output, grad_loss

@DeepSeekDDM @zwd003 Thanks a lot for helping.

Hunter-DDM commented 8 months ago


For the aux loss, the gradient with respect to itself is always 1: the forward pass returns x unchanged, and the backward pass routes grad_output through to x while sending a constant 1 to the aux loss. This is exactly equivalent to adding the aux loss to the final loss, so the loss gradient should not be grad_output * loss.
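
To make the equivalence concrete, here is a minimal sketch (not from the repo) that checks it numerically. The toy parameter w, the fake main loss, and the fake aux loss are made up for illustration; AddAuxiliaryLoss is the class quoted above, repeated so the sketch runs standalone.

import torch

# AddAuxiliaryLoss as quoted above, repeated here so the sketch is self-contained.
class AddAuxiliaryLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, loss):
        assert loss.numel() == 1
        ctx.dtype = loss.dtype
        ctx.required_aux_loss = loss.requires_grad
        # x is returned unchanged; the aux loss only matters in backward.
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad_loss = None
        if ctx.required_aux_loss:
            # d(aux_loss)/d(aux_loss) = 1, so a constant 1 is all that is needed.
            grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
        return grad_output, grad_loss

# Path A: attach the aux loss through AddAuxiliaryLoss, backprop only the main loss.
w = torch.randn(4, requires_grad=True)
x = AddAuxiliaryLoss.apply(w * 2.0, (w ** 2).sum())
x.sum().backward()
grad_a = w.grad.clone()

# Path B: add the aux loss to the final loss directly.
w.grad = None
((w * 2.0).sum() + (w ** 2).sum()).backward()
grad_b = w.grad.clone()

print(torch.allclose(grad_a, grad_b))  # True

Both paths produce the same w.grad, which shows that returning a constant 1 for the aux loss in backward behaves the same as literally adding the aux loss to the training loss.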