IrvingMeng / MagFace

MagFace: A Universal Representation for Face Recognition and Quality Assessment, CVPR2021, Oral
Apache License 2.0

Loss keeps oscillating during the last few epochs of each stage? #43

Closed changhy666 closed 2 years ago

changhy666 commented 2 years ago

[attached screenshot]

Questions:

  1. The loss keeps oscillating during the last few epochs of each stage, and the precision metric is lower than ArcFace + focal loss at the same point in training. What could be the cause?
  2. On roughly 300k training images, the learned feature magnitudes only span about 1 to 5. Why is the range so small, and why is it not strongly positively correlated with image quality?

Setup: 60 epochs in total, 4 stages, learning rate decayed by 0.1x at epochs [12, 24, 48], initial learning rate 0.1, dataset: deepglint, hyperparameters left at their defaults. I rewrote the code as follows:

    
    import math

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.nn import Parameter


    class MagFace(nn.Module):
        def __init__(self, in_features, out_features, device_id=None, s=64.0,
                     l_margin=0.45, u_margin=0.8, l_a=10, u_a=110,
                     easy_margin=False, fp16=False):
            super(MagFace, self).__init__()
            self.in_features = in_features
            self.out_features = out_features
            self.device_id = device_id

            self.s = s
            self.l_margin = l_margin
            self.u_margin = u_margin
            self.l_a = l_a
            self.u_a = u_a
            self.easy_margin = easy_margin
            self.fp16 = fp16

            self.weight = Parameter(torch.FloatTensor(out_features, in_features))
            self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
            # nn.init.xavier_uniform_(self.weight)

        def _margin(self, x):
            """Map the feature magnitude a in [l_a, u_a] linearly onto the
            margin range [l_margin, u_margin]."""
            margin = (self.u_margin - self.l_margin) / \
                     (self.u_a - self.l_a) * (x - self.l_a) + self.l_margin
            return margin

        def forward(self, x, target):
            """Scaled logits with a magnitude-adaptive margin m(a)."""
            # cosine similarity between normalized features and class weights
            if self.device_id is None:
                cos_theta = F.linear(F.normalize(x), F.normalize(self.weight))
            else:
                # model-parallel path: class weights are chunked across GPUs
                sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
                temp_x = x.cuda(self.device_id[0])
                weight = sub_weights[0].cuda(self.device_id[0])
                cos_theta = F.linear(F.normalize(temp_x), F.normalize(weight))
                for i in range(1, len(self.device_id)):
                    temp_x = x.cuda(self.device_id[i])
                    weight = sub_weights[i].cuda(self.device_id[i])
                    cos_theta = torch.cat(
                        (cos_theta,
                         F.linear(F.normalize(temp_x),
                                  F.normalize(weight)).cuda(self.device_id[0])),
                        dim=1)

            # per-sample margin from the clamped feature magnitude
            x_norm = torch.norm(x, dim=1, keepdim=True).clamp(self.l_a, self.u_a)
            ada_margin = self._margin(x_norm)
            cos_m, sin_m = torch.cos(ada_margin), torch.sin(ada_margin)

            # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m)
            cos_theta = cos_theta.clamp(-1, 1)
            sin_theta = torch.sqrt(1.0 - torch.pow(cos_theta, 2))
            cos_theta_m = cos_theta * cos_m - sin_theta * sin_m

            if self.fp16:
                cos_theta_m = cos_theta_m.half()
            if self.easy_margin:
                cos_theta_m = torch.where(cos_theta > 0, cos_theta_m, cos_theta)
            else:
                # keep the logit monotonic once theta + m passes pi
                mm = torch.sin(math.pi - ada_margin) * ada_margin
                if self.fp16:
                    mm = mm.half()
                threshold = torch.cos(math.pi - ada_margin)
                cos_theta_m = torch.where(cos_theta > threshold,
                                          cos_theta_m, cos_theta - mm)

            if self.device_id is not None:
                cos_theta = cos_theta.cuda(self.device_id[0])
            # margin is applied only to the target class, then scaled by s
            one_hot = torch.zeros_like(cos_theta)
            one_hot.scatter_(1, target.view(-1, 1), 1.0)
            output = one_hot * cos_theta_m + (1.0 - one_hot) * cos_theta
            output *= self.s

            return output

    class MagLoss(torch.nn.Module):
        """MagFace loss: cross-entropy plus the lambda_g-weighted magnitude
        regularizer g(a)."""

        def __init__(self, l_margin=0.45, u_margin=0.8, l_a=10, u_a=110,
                     s=60, lambda_g=35):
            super(MagLoss, self).__init__()
            self.s = s
            self.l_margin = l_margin
            self.u_margin = u_margin
            self.l_a = l_a
            self.u_a = u_a
            self.lambda_g = lambda_g

            # lower bound on lambda_g from the paper's convexity condition
            k = (self.u_margin - self.l_margin) / (self.u_a - self.l_a)
            min_lambda = self.s * k * self.u_a ** 2 * self.l_a ** 2 \
                / (self.u_a ** 2 - self.l_a ** 2)
            print('min lambda_g is {}, current lambda_g is {}'.format(
                min_lambda, self.lambda_g))

        def calc_loss_G(self, x_norm):
            # g(a) = a / u_a^2 + 1 / a, averaged over the batch
            g = 1 / (self.u_a ** 2) * x_norm + 1 / x_norm
            return torch.mean(g)

        def forward(self, input, target):
            # NOTE: `input` here is the scaled logits, not the feature
            # embedding -- see the resolution below
            x_norm = torch.norm(input, dim=1, keepdim=True).clamp(self.l_a, self.u_a)
            loss_g = self.calc_loss_G(x_norm)
            loss = F.cross_entropy(input, target, reduction='mean')
            return loss + self.lambda_g * loss_g
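
For reference, the training step wires the two modules together roughly like this (a minimal sketch; the dummy tensors stand in for the backbone output and labels):

    num_classes, emb_dim, batch = 1000, 512, 8
    head = MagFace(emb_dim, num_classes)
    criterion = MagLoss()

    feats = torch.randn(batch, emb_dim)               # stand-in for backbone features
    labels = torch.randint(0, num_classes, (batch,))
    logits = head(feats, labels)                      # scaled, margin-adjusted cosines
    loss = criterion(logits, labels)                  # loss_g sees the logits here
    loss.backward()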
changhy666 commented 2 years ago

Solved. My rewrite had a bug: the norm used by loss_g was always being clamped at 110.
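
For anyone who hits the same issue: `MagLoss.forward` above computes `x_norm` from the scaled logits rather than from the feature embedding, so the clamp pins it at `u_a = 110` and `loss_g` degenerates into a constant. A minimal sketch of one possible fix (passing the feature norm in explicitly; the class name and extra argument are mine, not the repo's API):

    import torch
    import torch.nn.functional as F

    class MagLossFixed(MagLoss):
        """Same loss, but g(a) is computed from the feature norm supplied
        by the caller instead of being re-derived from the logits."""

        def forward(self, input, target, x_norm):
            # x_norm: per-sample feature magnitudes, clamped to [l_a, u_a]
            loss_g = self.calc_loss_G(x_norm)
            loss_ce = F.cross_entropy(input, target, reduction='mean')
            return loss_ce + self.lambda_g * loss_g

    # usage: derive the norm from the raw embeddings, not from the logits
    # feats = backbone(images)                        # hypothetical backbone
    # x_norm = torch.norm(feats, dim=1, keepdim=True).clamp(10, 110)
    # loss = MagLossFixed()(head(feats, labels), labels, x_norm)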