ZF4444 / MMAL-Net

This is a PyTorch implementation of the paper "Multi-branch and Multi-scale Attention Learning for Fine-Grained Visual Categorization (MMAL-Net)" (Fan Zhang, Meng Li, Guisheng Zhai, Yizhao Liu).

Something went wrong when changing default config #37

Open · tungedng2710 opened this issue 2 years ago

tungedng2710 commented 2 years ago

I ran your code and trained it on the "Aircraft" dataset. It worked normally until I replaced your CE loss function with an ArcFace loss and your SGD optimizer with Adam. The code still runs, but it prints the log message "there is one img no intersection" and the accuracy is very low (approximately 1%). What went wrong?

Here is my ArcFace loss:

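import torch
import torch.nn as nn

class ArcFaceLoss(nn.Module):
    # Expects `input` to be cosine logits cos(theta) in [-1, 1]
    def __init__(self, s=30.0, m=0.50, is_cuda=True, base_loss='CrossEntropyLoss'):
        super(ArcFaceLoss, self).__init__()
        self.s = s  # logit scale
        self.m = m  # additive angular margin, in radians
        self.criterion = nn.CrossEntropyLoss()  # base_loss is currently unused
        if is_cuda:
            self.criterion = self.criterion.cuda()

    def forward(self, input, label):
        # Recover the angle; the clamp keeps acos finite at +/-1
        theta = torch.acos(torch.clamp(input, -1.0 + 1e-7, 1.0 - 1e-7))
        # Add the margin to the angle; only the target class uses it below
        target_logits = torch.cos(theta + self.m)
        one_hot = torch.zeros_like(input)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        output = input * (1 - one_hot) + target_logits * one_hot
        output = output * self.s
        return self.criterion(output, label.long())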
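One thing worth checking: the loss above only makes sense if `input` is cos(theta), i.e. the dot product of an L2-normalized feature with an L2-normalized class weight. If it receives the unnormalized logits of a plain `nn.Linear` classifier (whether the MMAL-Net heads work this way is an assumption to verify in the repo), most values fall outside [-1, 1], the clamp saturates them, and `torch.clamp` passes zero gradient for those entries, which could plausibly stall training. Below is a minimal sketch of a cosine classifier head paired with the same margin logic. `CosineClassifier` is a hypothetical name, and `in_features=2048` (ResNet-50 features) and `num_classes=100` (FGVC-Aircraft variants) are assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CosineClassifier(nn.Module):
    """Hypothetical head that outputs cos(theta) for every class.

    Features and class weights are both L2-normalized, so every logit lies
    in [-1, 1], which is the range the ArcFace margin formula assumes.
    """

    def __init__(self, in_features, num_classes):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(num_classes, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, features):
        # Cosine similarity between each sample and each class weight vector
        return F.linear(F.normalize(features, dim=1), F.normalize(self.weight, dim=1))


class ArcFaceLoss(nn.Module):
    """Additive angular margin loss applied to cosine logits."""

    def __init__(self, s=30.0, m=0.50):
        super().__init__()
        self.s = s
        self.m = m
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, cosine, label):
        theta = torch.acos(torch.clamp(cosine, -1.0 + 1e-7, 1.0 - 1e-7))
        target_logits = torch.cos(theta + self.m)
        one_hot = F.one_hot(label.long(), num_classes=cosine.size(1)).to(cosine.dtype)
        # Margin on the target class only, then scale all logits by s
        output = self.s * (cosine * (1 - one_hot) + target_logits * one_hot)
        return self.criterion(output, label.long())


# Usage with random stand-ins for backbone features and labels (CPU only)
torch.manual_seed(0)
head = CosineClassifier(in_features=2048, num_classes=100)
criterion = ArcFaceLoss()
features = torch.randn(8, 2048)
labels = torch.randint(0, 100, (8,))
cosine = head(features)
loss = criterion(cosine, labels)
loss.backward()  # gradients flow into head.weight because no logit is clamped
```

A design caveat: cos(theta + m) stops decreasing once theta + m exceeds pi, so the reference ArcFace implementation adds a threshold fallback for that case; this sketch omits it for brevity. If you use Adam with this setup, it may also be worth trying a smaller learning rate than the SGD default, since the scale `s` amplifies every logit.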