cvqluu / Angular-Penalty-Softmax-Losses-Pytorch

Angular penalty loss functions in Pytorch (ArcFace, SphereFace, Additive Margin, CosFace)
MIT License

Better loss implementation, full GPU support #7

Closed by laiviet 4 years ago

laiviet commented 4 years ago

This is my cleaner, higher-performance, fully GPU-supported implementation of the loss function. It does not require torch.cat (which is painfully slow). I excluded the FC layer and the normalization layer for my project.

import math
import torch
import torch.nn as nn

class AMSLoss(nn.Module):

    def __init__(self, m=1.0):
        super(AMSLoss, self).__init__()
        self.m = m
        self.one_minus_exp_m = 1.0 - math.exp(m)

    def forward(self, logits, labels, eps=1e-10):
        """
        :param logits: B x C
        :param labels: B
        :return: scalar loss
        """
        # Target logit per sample, logits[i, labels[i]], gathered without torch.cat
        numerator = torch.diagonal(logits.transpose(0, 1)[labels])
        # Algebraically equal to exp(target) + sum_{j != target} exp(logit_j + m),
        # i.e. the margin m is folded into the non-target terms of the denominator
        denominator = torch.sum(torch.exp(logits + self.m), dim=1) + torch.exp(numerator) * self.one_minus_exp_m
        L = numerator - torch.log(denominator + eps)
        return -torch.mean(L)
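
For reference, a minimal usage sketch. The embedding dimension, class count, and the cosine-logit construction below are illustrative assumptions, standing in for the FC and normalization layers that the snippet above leaves to the caller:

import torch
import torch.nn.functional as F

B, D, C = 8, 128, 10                                 # batch size, embedding dim, classes (hypothetical)
embeddings = torch.randn(B, D, requires_grad=True)   # stand-in for a network's output
weight = torch.randn(C, D)                           # stand-in for the excluded FC layer's weight

# Cosine-similarity logits in [-1, 1], i.e. the excluded normalization step
logits = F.normalize(embeddings) @ F.normalize(weight).t()  # B x C
labels = torch.randint(0, C, (B,))

criterion = AMSLoss(m=0.35)  # 0.35 is a commonly used AM-Softmax margin
loss = criterion(logits, labels)
loss.backward()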
jeannotes commented 3 years ago

It seems that in the numerator there is no scale factor and no margin subtraction. @laiviet @cvqluu
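
Note that the margin does get applied, just implicitly: expanding the denominator gives exp(logit_y) + exp(m) * sum_{j != y} exp(logit_j), which is exactly standard cross-entropy with m subtracted from the target logit. The scale factor s, however, is indeed absent. A small sketch checking the equivalence (AMSLoss is the class above; the margin and scale values are hypothetical):

import torch
import torch.nn.functional as F

logits = torch.empty(8, 10).uniform_(-1, 1)  # cosine-like logits in [-1, 1]
labels = torch.randint(0, 10, (8,))
m, s = 0.35, 30.0                            # hypothetical margin and scale

# AMSLoss(m) matches cross-entropy with m subtracted from the target logit
margin_logits = logits - m * F.one_hot(labels, num_classes=10)
assert torch.allclose(AMSLoss(m=m)(logits, labels),
                      F.cross_entropy(margin_logits, labels), atol=1e-4)

# The missing scale s amounts to scaling both the logits and the margin,
# reproducing the standard AM-Softmax target logit s * (cos_theta_y - m)
scaled_loss = AMSLoss(m=s * m)(s * logits, labels)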