supernotman / RetinaFace_Pytorch

Reimplement RetinaFace with Pytorch
305 stars 68 forks

focal_loss = False #8

Closed you-old closed 5 years ago

you-old commented 5 years ago

focal_loss = False

focal loss

        if focal_loss:
            alpha = 0.25
            gamma = 2.0            
            alpha_factor = torch.ones(targets.shape).cuda() * alpha

            alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor, 1. - alpha_factor)
            focal_weight = torch.where(torch.eq(targets, 1.), 1. - classification, classification)
            focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

            bce = -(targets * torch.log(classification) + (1.0 - targets) * torch.log(1.0 - classification))

            cls_loss = focal_weight * bce

            cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, torch.zeros(cls_loss.shape).cuda())

            classification_losses.append(cls_loss.sum()/torch.clamp(num_positive_anchors.float(), min=1.0))
        else:
            if positive_indices.sum() > 0:
                classification_losses.append(positive_losses.mean() + sorted_losses.mean())
            else:
                classification_losses.append(torch.tensor(0).float().cuda())

So focal loss is never used?
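For anyone who wants to run the quoted branch in isolation, here is a minimal CPU sketch of the same focal-loss computation. The tensor values and shapes are made up for illustration (the repo computes this per anchor, with targets of 1 = positive, 0 = negative, -1 = ignored), and the log inputs are clamped to avoid NaNs, which the original snippet does not do:

```python
import torch

# Hypothetical per-anchor sigmoid scores and targets {1, 0, -1}.
classification = torch.tensor([0.9, 0.2, 0.6, 0.4])
targets = torch.tensor([1., 0., -1., 1.])
num_positive_anchors = (targets == 1.).sum()

alpha, gamma = 0.25, 2.0
alpha_factor = torch.where(targets == 1.,
                           torch.full_like(targets, alpha),
                           torch.full_like(targets, 1. - alpha))
focal_weight = torch.where(targets == 1., 1. - classification, classification)
focal_weight = alpha_factor * focal_weight.pow(gamma)

# Clamp probabilities away from 0/1 so log() stays finite; treat ignored
# anchors as 0 here -- their loss is zeroed out below anyway.
p = classification.clamp(1e-6, 1. - 1e-6)
t = targets.clamp(min=0.)
bce = -(t * p.log() + (1. - t) * (1. - p).log())

cls_loss = focal_weight * bce
cls_loss = torch.where(targets != -1., cls_loss, torch.zeros_like(cls_loss))

loss = cls_loss.sum() / num_positive_anchors.float().clamp(min=1.0)
print(loss.item())
```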

luyao777 commented 5 years ago

Maybe there is no need to use focal loss

supernotman commented 5 years ago

never use focalloss???

I tried focal loss but failed. The model can't even converge during training.

supernotman commented 5 years ago


I guess the model needs some special initialization, like in RetinaNet (just a guess). You can give focal loss a try yourself; maybe I got something wrong.
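The "special initialization" in RetinaNet is the prior-probability bias on the final classification layer: the bias is set so every anchor starts with a low foreground probability (pi = 0.01), which keeps focal loss from being swamped by easy negatives early in training. A hedged sketch of that trick (the layer name and channel counts are illustrative, not taken from this repo):

```python
import math
import torch
import torch.nn as nn

pi = 0.01  # prior foreground probability, as in the RetinaNet paper

# Hypothetical final classification conv of a detection head.
cls_head = nn.Conv2d(256, 2, kernel_size=3, padding=1)

# b = -log((1 - pi) / pi)  so that  sigmoid(b) == pi  at initialization.
nn.init.constant_(cls_head.bias, -math.log((1 - pi) / pi))
nn.init.normal_(cls_head.weight, std=0.01)

print(torch.sigmoid(cls_head.bias[0]).item())  # ~0.01
```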

you-old commented 5 years ago

Hi, you don't even use cross-entropy, only .mean() + .mean(). Why? Does this work?

supernotman commented 5 years ago

Hi, you don't even use cross-entropy, only .mean() + .mean(). Why? Does this work?

Actually, I used binary cross-entropy loss for classification. In PyTorch, the CE loss function nn.CrossEntropyLoss is a combination of nn.LogSoftmax and nn.NLLLoss(). In my code, I just use nn.LogSoftmax() and implement nn.NLLLoss() another way. For details on the loss, see this blog: https://blog.csdn.net/watermelon1123/article/details/91044856
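The decomposition described above is easy to verify: nn.CrossEntropyLoss produces the same result as nn.LogSoftmax followed by nn.NLLLoss. A quick check with random logits (shapes are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 2)           # 4 anchors, 2 classes (face / background)
labels = torch.tensor([0, 1, 1, 0])

ce = nn.CrossEntropyLoss()(logits, labels)
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), labels)

print(torch.allclose(ce, nll))  # True
```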

you-old commented 5 years ago

Understood, thanks, author, you are very smart. Can you upload your pretrained model?