This is my cleaner, higher-performance, fully GPU-friendly implementation of the loss function. It does not require torch.cat (which is painfully slow). I exclude the FC layer and the normalization layer in my project, so forward takes precomputed logits (a sketch of the excluded head is at the bottom for anyone who needs it).
import math
import torch
import torch.nn as nn


class AMSLoss(nn.Module):
    def __init__(self, m=1.0):
        super(AMSLoss, self).__init__()
        self.m = m
        # Precompute (1 - e^m); it reappears in every forward pass.
        self.one_minus_exp_m = 1.0 - math.exp(m)

    def forward(self, logits, labels, eps=1e-10):
        """
        :param logits: B x C
        :param labels: B
        :return: scalar loss
        """
        # Target logit per sample: numerator[i] = logits[i, labels[i]].
        numerator = torch.diagonal(logits.transpose(0, 1)[labels])
        # sum_j exp(l_j + m) + exp(l_y) * (1 - exp(m))
        #   = exp(l_y) + exp(m) * sum_{j != y} exp(l_j),
        # i.e. the margin penalizes every non-target logit, with no need
        # to slice out the target column and torch.cat it back together.
        denominator = torch.sum(torch.exp(logits + self.m), dim=1) + torch.exp(numerator) * self.one_minus_exp_m
        L = numerator - torch.log(denominator + eps)
        return -torch.mean(L)
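
For anyone who wants a quick sanity check, this is how I'd call it; the batch size of 4, the 10 classes, and the random tensors are just made-up example values:

# Hypothetical smoke test: 4 samples, 10 classes.
loss_fn = AMSLoss(m=1.0)
logits = torch.randn(4, 10)          # B x C scores from your own head
labels = torch.randint(0, 10, (4,))  # B ground-truth class indices
loss = loss_fn(logits, labels)
print(loss.item())  # a single positive scalar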
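As for the part I left out: with AM-Softmax-style losses the logits are usually cosine similarities from an L2-normalized linear layer. Here is a minimal sketch of such a head; the name NormalizedHead and its interface are illustrative, not part of my training code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class NormalizedHead(nn.Module):
    # Hypothetical cosine head: L2-normalizes both the embeddings and the
    # class weights, so its B x C output can be fed to AMSLoss as logits.
    def __init__(self, in_features, num_classes):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_classes, in_features))

    def forward(self, x):
        # cos(theta) between each embedding and each class-weight vector
        return F.linear(F.normalize(x), F.normalize(self.weight))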