amirhfarzaneh / dacl

Deep Attentive Center Loss

Can the SparseCenterLoss be implemented in another way? #11

Open cotyyang opened 1 year ago

cotyyang commented 1 year ago

Hello, your loss.py is similar to center_loss_pytorch, but center loss also has another implementation (https://github.com/KaiyangZhou/pytorch-center-loss/issues/20). Could the scheme therefore be modified as follows?

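import torch
import torch.nn as nn


class CenterLoss(nn.Module):
    def __init__(self, num_class=10, num_feature=2):
        super(CenterLoss, self).__init__()
        self.num_class = num_class
        self.num_feature = num_feature
        # One learnable center per class; the optimizer updates it through autograd
        self.centers = nn.Parameter(torch.randn(self.num_class, self.num_feature))

    def forward(self, feat, A, label):
        # feat: (B, num_feature); A: weights broadcastable to (B, num_feature);
        # label: (B,) integer class indices (LongTensor)
        center = self.centers[label]
        # Weighted squared distance of each sample to its class center: (B,)
        dist = (A * ((feat - center).pow(2))).sum(dim=-1)
        # Clamp for numerical stability, then average over the batch (scalar)
        loss = torch.clamp(dist, min=1e-12, max=1e+12).mean(dim=-1)

        return loss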
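Because the proposed loss is written entirely with ordinary tensor operations, autograd derives its backward pass, so no hand-written gradient is required. The sketch below checks this by comparing autograd's gradient on the centers with the analytic one, dL/dc_j = -(2/B) * sum over samples i with y_i = j of A_i * (x_i - c_j). It assumes `A` has shape `(B, d)` and `label` is a `LongTensor`. `sparse_center_loss` is a hypothetical stand-alone helper mirroring the `forward` above and is not part of dacl.

```python
import torch


def sparse_center_loss(feat, A, label, centers):
    """Hypothetical helper: attention-weighted center loss averaged over the batch."""
    # Gather each sample's class center: (B, d)
    center = centers[label]
    # Weighted squared distance per sample: (B,)
    dist = (A * (feat - center).pow(2)).sum(dim=-1)
    # Clamp for numerical safety, then average over the batch -> scalar
    return torch.clamp(dist, min=1e-12, max=1e12).mean(dim=-1)


torch.manual_seed(0)
B, num_class, d = 8, 3, 4
feat = torch.randn(B, d, dtype=torch.float64)
A = torch.rand(B, d, dtype=torch.float64)  # assumed per-dimension weights in [0, 1)
label = torch.randint(0, num_class, (B,))
centers = torch.randn(num_class, d, dtype=torch.float64, requires_grad=True)

loss = sparse_center_loss(feat, A, label, centers)
loss.backward()

# Analytic gradient: rows of classes absent from the batch stay zero
with torch.no_grad():
    expected = torch.zeros_like(centers)
    diff = A * (feat - centers[label])
    expected.index_add_(0, label, -2.0 / B * diff)

print(torch.allclose(centers.grad, expected))  # True
```

With this formulation the centers are updated by the optimizer like any other `nn.Parameter`, using exactly this gradient, rather than by a separate center-update rule.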