clcarwin / sphereface_pytorch

A PyTorch Implementation of SphereFace.
MIT License

when CrossEntropyLoss loss=nan #24

Closed JonesonZheng closed 6 years ago

JonesonZheng commented 6 years ago

When I use

    loss = nn.CrossEntropyLoss()(output, target)

to replace the following code in AngleLoss:

    logpt = F.log_softmax(output)
    logpt = logpt.gather(1,target)
    logpt = logpt.view(-1)
    pt = Variable(logpt.data.exp())
    loss = -1 * (1-pt)**self.gamma * logpt
    loss = loss.mean()

I get loss=nan.

In theory, I think logpt = F.log_softmax(output) here should be logpt = F.log_softmax(output,1), which is what CrossEntropyLoss computes, but your implementation is effectively logpt = F.log_softmax(output,0). Could you help explain this?
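For reference, the relationship between the two losses can be written out as below. This is only a sketch of the math implied by the snippet above: the symbols $z_{i,j}$ (logit of sample $i$ for class $j$), $y_i$ (target class of sample $i$), and $N$ (batch size) are notation introduced here and do not come from the repository.

```latex
p_i = \frac{\exp(z_{i,y_i})}{\sum_{j} \exp(z_{i,j})},
\qquad
L_{\text{focal}} = -\frac{1}{N} \sum_{i=1}^{N} (1 - p_i)^{\gamma} \, \log p_i ,
\qquad
L_{\text{focal}} \Big|_{\gamma = 0} = -\frac{1}{N} \sum_{i=1}^{N} \log p_i = L_{\text{CE}} .
```

The normalizing sum runs over the class index $j$, which for an output of shape (batch, classes) is dimension 1. Normalizing over dimension 0 would instead sum over the samples in the batch, and the result would no longer reduce to cross-entropy at $\gamma = 0$.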

clcarwin commented 6 years ago

You can try to compare the code in focal_loss_pytorch.

suanrong commented 5 years ago

@JonesonZheng

I think you are right. It should be logpt = F.log_softmax(output,1).
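A quick way to check this suggestion is to compare the focal-style loss against nn.CrossEntropyLoss with gamma set to 0. The sketch below is not the repository's AngleLoss: focal_loss is a hypothetical helper that keeps only the lines quoted in this issue, it assumes output is a plain (batch, classes) logit tensor and target a 1-D tensor of class indices, and it uses detach() where the original uses Variable(logpt.data.exp()).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def focal_loss(output, target, gamma=0.0):
    """Hypothetical helper: the focal-loss lines from this issue with an explicit dim=1."""
    # Normalize over the class dimension (dim=1), as CrossEntropyLoss does.
    logpt = F.log_softmax(output, dim=1)
    # Pick the log-probability of the target class for each sample.
    logpt = logpt.gather(1, target.view(-1, 1)).view(-1)
    # pt is treated as a constant weight, so it is detached from the graph.
    pt = logpt.detach().exp()
    loss = -1 * (1 - pt) ** gamma * logpt
    return loss.mean()


torch.manual_seed(0)
output = torch.randn(4, 10)          # assumed shape: (batch, classes)
target = torch.randint(0, 10, (4,))  # assumed shape: (batch,)

ce = nn.CrossEntropyLoss()(output, target)
fl = focal_loss(output, target, gamma=0.0)

# With gamma = 0 the focal weighting is 1, so the two losses should agree.
matches = torch.allclose(ce, fl, atol=1e-6)
print(ce.item(), fl.item(), matches)
```

If the two values agree here but the loss still becomes NaN in training, the cause is likely elsewhere (for example in the values fed into the loss) rather than in the choice of dim.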