KaiyangZhou / pytorch-center-loss

Pytorch implementation of Center Loss
MIT License
970 stars, 219 forks

Doesn't anyone think the author's center loss is too complicated? #20

Open crj1998 opened 3 years ago

crj1998 commented 3 years ago

A more concise, easier-to-understand version:

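```python
import torch
import torch.nn as nn


class CenterLoss(nn.Module):
    def __init__(self, num_class=10, num_feature=2):
        super().__init__()
        self.num_class = num_class
        self.num_feature = num_feature
        # One learnable center per class, shape (num_class, num_feature)
        self.centers = nn.Parameter(torch.randn(num_class, num_feature))

    def forward(self, x, labels):
        # x: (batch, num_feature) features; labels: (batch,) int64 class indices
        center = self.centers[labels]           # center of each sample's own class
        dist = (x - center).pow(2).sum(dim=-1)  # squared Euclidean distance per sample
        # Clamp for numerical safety, then average over the batch
        loss = torch.clamp(dist, min=1e-12, max=1e+12).mean(dim=-1)
        return loss
```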
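For reference, ignoring the clamp, the forward pass above computes the batch-averaged squared distance between each feature x_i and the center of its class y_i over a mini-batch of size m. The Center Loss paper (Wen et al., 2016) writes the same term with a factor of 1/2 and a plain sum, so the two differ by the factor m/2; with a fixed batch size that only rescales the weight λ in the total loss L = L_S + λ L_C, where L_S is the softmax (cross-entropy) loss:

```latex
\mathcal{L}_C^{\text{above}} = \frac{1}{m}\sum_{i=1}^{m} \lVert \mathbf{x}_i - \mathbf{c}_{y_i} \rVert_2^2,
\qquad
\mathcal{L}_C^{\text{paper}} = \frac{1}{2}\sum_{i=1}^{m} \lVert \mathbf{x}_i - \mathbf{c}_{y_i} \rVert_2^2
```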
ajndkr commented 3 years ago

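@crj1998 I agree. The above solution is more efficient, since it only computes each sample's distance to its own class center instead of a full batch-by-class distance matrix. I suggest you open a PR on the author's repository.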

RaduFilip16 commented 2 years ago

Can you please give an example of how to use this?
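A minimal usage sketch, under assumptions that are not from the repository: the toy model `Net`, the helper `train`, the weight `lambda_c = 0.01`, the learning rates and the random data are all hypothetical choices for illustration. The center loss is applied to the feature layer and added to the usual cross-entropy on the logits; the centers get their own optimizer so their learning rate can be tuned separately.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CenterLoss(nn.Module):
    """Concise center loss from the first comment (same computation)."""

    def __init__(self, num_class=10, num_feature=2):
        super().__init__()
        self.centers = nn.Parameter(torch.randn(num_class, num_feature))

    def forward(self, x, labels):
        center = self.centers[labels]
        dist = (x - center).pow(2).sum(dim=-1)
        return torch.clamp(dist, min=1e-12, max=1e12).mean()


class Net(nn.Module):
    """Hypothetical toy model returning (features, logits)."""

    def __init__(self, in_dim=20, num_feature=2, num_class=10):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(in_dim, 64), nn.ReLU(), nn.Linear(64, num_feature))
        self.head = nn.Linear(num_feature, num_class)

    def forward(self, x):
        feat = self.body(x)
        return feat, self.head(feat)


def train(steps=100, lambda_c=0.01, seed=0):
    """Hypothetical training loop; lambda_c must be > 0 (see gradient rescaling)."""
    torch.manual_seed(seed)
    model = Net()
    center_loss = CenterLoss(num_class=10, num_feature=2)
    # Separate optimizers: one for the network, one for the class centers
    opt_model = torch.optim.SGD(model.parameters(), lr=0.1)
    opt_center = torch.optim.SGD(center_loss.parameters(), lr=0.5)

    # Random stand-in data (hypothetical); replace with a real DataLoader
    x = torch.randn(256, 20)
    y = torch.randint(0, 10, (256,))

    history = []
    for _ in range(steps):
        feat, logits = model(x)
        loss_c = center_loss(feat, y)  # applied to features, not logits
        loss = F.cross_entropy(logits, y) + lambda_c * loss_c
        opt_model.zero_grad()
        opt_center.zero_grad()
        loss.backward()
        # Optional: undo the lambda_c scaling on the centers' gradient so that
        # lambda_c only controls how strongly features are pulled to centers
        for p in center_loss.parameters():
            p.grad.mul_(1.0 / lambda_c)
        opt_model.step()
        opt_center.step()
        history.append(loss_c.item())
    return history
```

The gradient rescaling is a design choice: without it, the centers' effective learning rate would be `0.5 * lambda_c`, so changing `lambda_c` would also change how quickly the centers track the features.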

Anson-He commented 1 year ago

Does anyone know whether this version works?
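One way to check is to compare the concise version against a full distance-matrix formulation, which, as far as I understand it, is roughly what the repository's original implementation does: compute the distance from every sample to every center, then mask out all but the true class. `reference_center_loss` and `compare` below are hypothetical helpers written for this test, not the repository's code. The sketch compares both the loss values and the gradients with respect to the features and the centers.

```python
import torch
import torch.nn as nn


class CenterLoss(nn.Module):
    """Concise center loss from the first comment (same computation)."""

    def __init__(self, num_class=10, num_feature=2):
        super().__init__()
        self.centers = nn.Parameter(torch.randn(num_class, num_feature))

    def forward(self, x, labels):
        center = self.centers[labels]
        dist = (x - center).pow(2).sum(dim=-1)
        return torch.clamp(dist, min=1e-12, max=1e12).mean()


def reference_center_loss(x, labels, centers):
    """Hypothetical full-distance-matrix formulation, for comparison only.

    Computes ||x_i||^2 + ||c_j||^2 - 2 x_i . c_j for every (sample, class)
    pair, then keeps only each sample's own class with a one-hot mask.
    """
    num_class = centers.size(0)
    distmat = (
        x.pow(2).sum(dim=1, keepdim=True)           # (B, 1)
        + centers.pow(2).sum(dim=1).unsqueeze(0)    # (1, C)
        - 2.0 * x @ centers.t()                     # (B, C)
    )
    mask = nn.functional.one_hot(labels, num_class).to(x.dtype)  # (B, C)
    # Masked-out zeros are clamped to 1e-12, adding a negligible offset
    dist = (distmat * mask).clamp(min=1e-12, max=1e12)
    return dist.sum() / x.size(0)


def compare(batch=32, num_class=10, num_feature=2, seed=0):
    torch.manual_seed(seed)
    # Double precision keeps the expanded-square formula's rounding error negligible
    crit = CenterLoss(num_class, num_feature).double()
    x = torch.randn(batch, num_feature, dtype=torch.float64, requires_grad=True)
    labels = torch.randint(0, num_class, (batch,))

    loss_a = crit(x, labels)
    grad_x_a, grad_c_a = torch.autograd.grad(loss_a, [x, crit.centers])

    loss_b = reference_center_loss(x, labels, crit.centers)
    grad_x_b, grad_c_b = torch.autograd.grad(loss_b, [x, crit.centers])

    return (
        torch.allclose(loss_a, loss_b),
        torch.allclose(grad_x_a, grad_x_b),
        torch.allclose(grad_c_a, grad_c_b),
    )


print(compare())  # expected: (True, True, True)
```

The concise version only materializes tensors of shape `(batch, num_feature)`, while the matrix formulation builds a `(batch, num_class)` distance matrix, which is presumably where the efficiency gain mentioned above comes from when `num_class` is large.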