scaelles / DEXTR-PyTorch

Deep Extreme Cut http://www.vision.ee.ethz.ch/~cvlsegmentation/dextr

Question about Loss #36

Closed mt-cly closed 2 years ago

mt-cly commented 2 years ago

Hi, thanks for your great work. I have some confusion about the class-balanced cross-entropy loss at https://github.com/scaelles/DEXTR-PyTorch/blob/master/layers/loss.py#L26 . I notice that output_gt_zero makes positive and negative predictions be handled differently, and I do not understand why. In my understanding, loss_val should equal log(sigmoid(output)) when label == 1, and log(1 - sigmoid(output)) when label == 0. However, this does not match your code. Could you please explain it or point me to some paper references? Thank you.

scaelles commented 2 years ago

Hello,

Thanks for your interest! Since we predict masks after cropping the image around the extreme points, there is an imbalance between the number of foreground and background pixels: there are far more foreground pixels than background ones, which can bias the network towards being overconfident about the foreground. To alleviate that problem, we use the class balancing that was originally introduced in [1].

[1]: S. Xie and Z. Tu, Holistically-Nested Edge Detection, ICCV 2015.
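In a nutshell, the idea from [1] is to weight the foreground and background terms of the cross-entropy by the frequency of the opposite class, so that the rarer class is not drowned out. A minimal sketch of that weighting (not the exact code from this repo, which also uses a numerically stable formulation, see below):

import torch
import torch.nn.functional as F

def balanced_bce(output, labels):
    # output: raw logits; labels: binary mask in {0, 1} with the same shape.
    num_pos = torch.sum(labels)
    num_neg = labels.numel() - num_pos
    num_total = float(labels.numel())

    # Per-pixel log-likelihoods: log(sigmoid(x)) and log(1 - sigmoid(x)).
    loss_pos = -torch.sum(labels * F.logsigmoid(output))
    loss_neg = -torch.sum((1. - labels) * F.logsigmoid(-output))

    # Weight each term by the opposite-class frequency, as in [1].
    return (num_neg / num_total) * loss_pos + (num_pos / num_total) * loss_neg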

mt-cly commented 2 years ago

Hi, thanks for your reply. I read the reference paper; it introduces a balancing weight on each of the two loss terms, which corresponds to your num_labels_neg/num_total and num_labels_pos/num_total. But there is still a gap between that and your implementation: I am not clear about the loss_val calculation. The following is copied from your code:

output_gt_zero = torch.ge(output, 0).float()
loss_val = torch.mul(output, (labels - output_gt_zero)) - torch.log(
    1 + torch.exp(output - 2 * torch.mul(output, output_gt_zero)))

I wonder why it is not as follows:

loss_val = torch.mul(labels, torch.log(1 / (1 + torch.exp(-output)))) + torch.mul(1. - labels, torch.log(1 - 1 / (1 + torch.exp(-output))))

Thank you.

scaelles commented 2 years ago

Hello,

The loss that you describe is indeed the theoretical definition (you can find it here in one of our previous projects). However, it can be numerically unstable during training, so the formulation is rearranged into an equivalent form that behaves better. You can find here a similar derivation from the theoretical form to the practical one (it is not exactly the same, but you should be able to derive ours from it).
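To sketch the equivalence: splitting on the sign of output, the expression in loss.py reduces to

output * (labels - 1) - log(1 + exp(-output))    if output >= 0
output * labels - log(1 + exp(output))           if output < 0

and substituting labels = 1 or labels = 0 recovers log(sigmoid(output)) and log(1 - sigmoid(output)) in both cases. The point is that the exp() argument is always -|output|, so it can never overflow. A small self-contained check (not code from this repo):

import torch

# Stable form from loss.py: the exp() argument is always -|output|.
def stable_log_likelihood(output, labels):
    output_gt_zero = torch.ge(output, 0).float()
    return torch.mul(output, (labels - output_gt_zero)) - torch.log(
        1 + torch.exp(output - 2 * torch.mul(output, output_gt_zero)))

# Naive form from the question; sigmoid() saturates for large |output|,
# so the log() blows up to -inf.
def naive_log_likelihood(output, labels):
    p = torch.sigmoid(output)
    return labels * torch.log(p) + (1. - labels) * torch.log(1. - p)

output = torch.tensor([-200., -1., 1., 200.])
labels = torch.tensor([1., 0., 1., 0.])
print(stable_log_likelihood(output, labels))  # all entries finite
print(naive_log_likelihood(output, labels))   # -inf at the extremes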

Also, note that the actual class balancing happens here.
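Putting the two pieces together, the overall structure is roughly the following (a paraphrase, not the exact code; see layers/loss.py for the real implementation):

import torch

def class_balanced_cross_entropy(output, labels):
    # Stable per-pixel log-likelihood, same expression as quoted above.
    output_gt_zero = torch.ge(output, 0).float()
    loss_val = torch.mul(output, (labels - output_gt_zero)) - torch.log(
        1 + torch.exp(output - 2 * torch.mul(output, output_gt_zero)))

    # Negative log-likelihood, split into foreground and background sums.
    loss_pos = -torch.sum(torch.mul(labels, loss_val))
    loss_neg = -torch.sum(torch.mul(1. - labels, loss_val))

    # Class balancing: each sum is weighted by the opposite-class frequency.
    num_labels_pos = torch.sum(labels)
    num_labels_neg = torch.sum(1. - labels)
    num_total = num_labels_pos + num_labels_neg

    return num_labels_neg / num_total * loss_pos + \
           num_labels_pos / num_total * loss_neg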

Let me know if you have any other questions!