crj1998 opened 3 years ago
A concise, easy-to-understand version:
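```python
import torch
import torch.nn as nn


class CenterLoss(nn.Module):
    def __init__(self, num_class=10, num_feature=2):
        super(CenterLoss, self).__init__()
        self.num_class = num_class
        self.num_feature = num_feature
        # One learnable center per class, shape (num_class, num_feature)
        self.centers = nn.Parameter(torch.randn(self.num_class, self.num_feature))

    def forward(self, x, labels):
        # Center of each sample's class, shape (batch, num_feature)
        center = self.centers[labels]
        # Squared Euclidean distance from each feature to its class center, shape (batch,)
        dist = (x - center).pow(2).sum(dim=-1)
        # Clamp the distances to [1e-12, 1e12], then average over the batch
        loss = torch.clamp(dist, min=1e-12, max=1e+12).mean(dim=-1)
        return loss
```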
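For reference, this is what `forward` computes for a batch of $m$ features $x_i$ with labels $y_i$ and learnable class centers $c_k$ (the symbols are introduced here for illustration). The only departure from a plain mean of squared distances is the clamp:

$$
\mathcal{L}_C = \frac{1}{m}\sum_{i=1}^{m} \operatorname{clamp}\left(\lVert x_i - c_{y_i} \rVert_2^2,\ 10^{-12},\ 10^{12}\right)
$$

Because it averages over the batch, its scale does not grow with batch size; any constant factor relative to other formulations can be absorbed into the weight used when adding it to the classification loss.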
@crj1998 I agree. The above solution is more efficient. I suggest you open a PR for the author.
Can you please give an example of how to use this?
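A minimal usage sketch, assuming a model that returns both the feature embedding and the logits. The `Net` class, the weight `lambda_c = 0.1`, the learning rates, and the synthetic data are all hypothetical choices, not something from this thread. The center loss is added to the usual cross-entropy, and the centers get their own optimizer so their learning rate can be tuned separately from the network's.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CenterLoss(nn.Module):
    """Same computation as the CenterLoss in the issue (attribute bookkeeping trimmed)."""

    def __init__(self, num_class=10, num_feature=2):
        super().__init__()
        self.centers = nn.Parameter(torch.randn(num_class, num_feature))

    def forward(self, x, labels):
        dist = (x - self.centers[labels]).pow(2).sum(dim=-1)
        return torch.clamp(dist, min=1e-12, max=1e+12).mean(dim=-1)


class Net(nn.Module):
    """Hypothetical toy model: returns (features, logits)."""

    def __init__(self, in_dim=20, num_feature=2, num_class=10):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(in_dim, 64), nn.ReLU(), nn.Linear(64, num_feature)
        )
        self.classifier = nn.Linear(num_feature, num_class)

    def forward(self, x):
        feat = self.backbone(x)
        return feat, self.classifier(feat)


def train(steps=50, lambda_c=0.1, seed=0):
    torch.manual_seed(seed)
    model = Net()
    center_loss = CenterLoss(num_class=10, num_feature=2)
    # Separate optimizers: the centers can use their own (assumed) learning rate
    opt_model = torch.optim.SGD(model.parameters(), lr=0.1)
    opt_center = torch.optim.SGD(center_loss.parameters(), lr=0.5)

    x = torch.randn(128, 20)          # synthetic inputs
    y = torch.randint(0, 10, (128,))  # synthetic labels
    history = []
    for _ in range(steps):
        feat, logits = model(x)
        # Joint objective: cross-entropy plus weighted center loss on the features
        loss = F.cross_entropy(logits, y) + lambda_c * center_loss(feat, y)
        opt_model.zero_grad()
        opt_center.zero_grad()
        loss.backward()
        opt_model.step()
        opt_center.step()
        history.append(loss.item())
    return history
```

Calling `train()` runs the loop and returns the combined loss at each step. Note that the center loss is applied to the features (the input of the classifier), not to the logits.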
Does anyone know if that works?
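One way to check that the snippet computes what it claims, sketched with only PyTorch: compare it against a naive per-sample loop, confirm the loss drops to the 1e-12 clamp floor when every feature sits exactly on its class center, and confirm that a gradient step on the centers alone lowers the loss. `naive_center_loss` and `check` are hypothetical helper names. Passing these checks shows the loss and its gradients are correct; it says nothing about how much center loss helps a particular model.

```python
import torch
import torch.nn as nn


class CenterLoss(nn.Module):
    """Same computation as the CenterLoss in the issue (attribute bookkeeping trimmed)."""

    def __init__(self, num_class=10, num_feature=2):
        super().__init__()
        self.centers = nn.Parameter(torch.randn(num_class, num_feature))

    def forward(self, x, labels):
        dist = (x - self.centers[labels]).pow(2).sum(dim=-1)
        return torch.clamp(dist, min=1e-12, max=1e+12).mean(dim=-1)


def naive_center_loss(x, labels, centers):
    """Reference: average squared distance, computed one sample at a time."""
    total = 0.0
    for i in range(x.shape[0]):
        diff = x[i] - centers[labels[i]]
        total += float((diff * diff).sum())
    return total / x.shape[0]


def check(seed=0):
    torch.manual_seed(seed)
    crit = CenterLoss(num_class=5, num_feature=3)
    x = torch.randn(32, 3)
    y = torch.randint(0, 5, (32,))

    # 1) Vectorized loss matches the per-sample reference
    vec = crit(x, y).item()
    ref = naive_center_loss(x, y, crit.centers.detach())
    matches = abs(vec - ref) < 1e-4

    # 2) Features placed exactly on their centers give the clamp floor (about 1e-12)
    on_center = crit(crit.centers.detach()[y], y).item()

    # 3) One SGD step on the centers moves them toward their class means
    opt = torch.optim.SGD(crit.parameters(), lr=0.1)
    before = crit(x, y)
    opt.zero_grad()
    before.backward()
    opt.step()
    after = crit(x, y).item()
    return matches, on_center, before.item(), after
```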