MIC-DKFZ / medicaldetectiontoolkit

The Medical Detection Toolkit contains 2D + 3D implementations of prevalent object detectors such as Mask R-CNN, Retina Net, Retina U-Net, as well as a training and inference framework focused on dealing with medical images.
Apache License 2.0

Focal loss for classification in Retina-Net #53

Closed: oetzus closed this issue 5 years ago

oetzus commented 5 years ago

Hi Paul,

I was trying to use the Retina Net when I looked at the implementation and noticed that you are not using focal loss for classification, as the original paper does, but SHEM instead. Is there a reason for that? Furthermore, you apply the softmax function to the logits of the negative examples, but not to those of the positive ones.

pfjaeger commented 5 years ago

Hi!

I tried Focal Loss, but it did not improve performance, so I made all classification losses use SHEM for comparability’s sake. You are welcome to try focal loss yourself; it should be a one-line modification.
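Since the reply describes swapping in focal loss as a one-line change, here is a minimal PyTorch sketch of a loss that could stand in for a call like `F.cross_entropy(logits, targets)`. It is a softmax-based multi-class variant (the RetinaNet paper itself uses per-class sigmoids), and `focal_loss`, `gamma` and `alpha` are illustrative names, not part of the toolkit.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def focal_loss(logits: torch.Tensor, targets: torch.Tensor,
               gamma: float = 2.0, alpha: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Multi-class (softmax) focal loss, averaged over samples.

    Hypothetical helper for illustration, not the toolkit's implementation.
    logits:  (N, C) raw class scores, no softmax applied
    targets: (N,) integer class indices
    alpha:   optional (C,) tensor of per-class weights
    """
    # log p_t for each sample's target class, computed stably from the logits
    log_pt = F.log_softmax(logits, dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()
    # modulating factor (1 - p_t)^gamma down-weights well-classified examples
    loss = -((1.0 - pt) ** gamma) * log_pt
    if alpha is not None:
        # optional per-class weighting, looked up by each sample's target class
        loss = alpha.to(logits.device)[targets] * loss
    return loss.mean()
```

With `gamma=0` and no `alpha` it reduces to plain cross entropy, which is a quick sanity check before tuning `gamma`.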

As for your question about softmax, I am not sure which of two things you are asking:

1) Are you confused that softmax is only called explicitly for the negative examples? That is because F.cross_entropy() takes raw logits and applies (log-)softmax internally, so the explicit softmax is only needed for SHEM.

2) Are you asking why SHEM is only applied to negative samples? This is standard procedure, based on the assumption that the number of positive samples in the dataset is limited, so every one of them is seen by the model sufficiently often.
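A minimal sketch of both points, assuming a PyTorch setup where class 0 is background and `logits` holds raw scores for each candidate. The helpers `shem_select` and `classification_loss`, the 3:1 negative ratio and the pool factor are hypothetical choices for illustration, not the toolkit's own code. Softmax is applied explicitly only to rank negatives; the loss itself is `F.cross_entropy` on logits for all positives plus a sampled subset of hard negatives.

```python
import torch
import torch.nn.functional as F


def shem_select(neg_logits: torch.Tensor, n_keep: int, pool_factor: int = 10) -> torch.Tensor:
    """Return indices (into neg_logits) of negatives chosen by stochastic hard example mining.

    Assumed scheme (hypothetical helper): rank negatives by their highest
    foreground probability, keep the hardest n_keep * pool_factor of them,
    then sample n_keep from that pool uniformly at random.
    """
    # explicit softmax is needed here only to *rank* negatives, not for the loss
    probs = F.softmax(neg_logits, dim=1)
    fg_score = probs[:, 1:].max(dim=1).values  # class 0 is background
    pool_size = min(n_keep * pool_factor, neg_logits.shape[0])
    pool = torch.topk(fg_score, pool_size).indices
    n_keep = min(n_keep, pool_size)
    return pool[torch.randperm(pool_size, device=pool.device)[:n_keep]]


def classification_loss(logits: torch.Tensor, targets: torch.Tensor,
                        neg_ratio: int = 3, pool_factor: int = 10) -> torch.Tensor:
    """Cross entropy on all positives plus a SHEM-selected subset of negatives."""
    pos_idx = torch.nonzero(targets > 0).squeeze(1)   # every positive is kept
    neg_idx = torch.nonzero(targets == 0).squeeze(1)  # negatives get subsampled
    n_neg = max(1, neg_ratio * pos_idx.numel())
    hard_neg = neg_idx[shem_select(logits[neg_idx], n_neg, pool_factor)]
    keep = torch.cat([pos_idx, hard_neg])
    # F.cross_entropy takes raw logits and applies log-softmax internally
    return F.cross_entropy(logits[keep], targets[keep])
```

Sampling randomly from a pool of hard negatives, rather than always taking the top-k, is what makes the mining stochastic and keeps a few noisy, extremely hard negatives from dominating every batch.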

oetzus commented 5 years ago

Hi Paul,

thanks for your answer! I will give focal loss a try; maybe it works well on my data.

Regarding the second question, your first point cleared up what I was confused about. I had missed that the softmax is only used for SHEM, and thought that the cross entropy was applied to the softmax outputs for the negative samples, but not for the positive ones.
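The misunderstanding resolved here is easy to check directly. This small PyTorch sketch uses hand-picked logits: `F.cross_entropy` on raw logits matches `F.nll_loss` on `F.log_softmax`, whereas feeding it softmax outputs applies softmax a second time and gives a different, larger loss.

```python
import torch
import torch.nn.functional as F

# two samples, three classes, both confidently predicted as their target class
logits = torch.tensor([[4.0, 0.0, 0.0],
                       [0.0, 0.0, 4.0]])
targets = torch.tensor([0, 2])

# intended use: raw logits in, log-softmax applied internally
ce_from_logits = F.cross_entropy(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)

# mistake: passing probabilities means softmax is effectively applied twice
ce_double_softmax = F.cross_entropy(F.softmax(logits, dim=1), targets)

print(torch.allclose(ce_from_logits, manual))             # True
print(torch.allclose(ce_from_logits, ce_double_softmax))  # False
```

Here the correct loss is about 0.036, while the double-softmax version is about 0.574, because squashing probabilities through a second softmax pushes them toward uniform.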

Thanks for the clarification!