MIC-DKFZ / medicaldetectiontoolkit

The Medical Detection Toolkit contains 2D + 3D implementations of prevalent object detectors such as Mask R-CNN, Retina Net, Retina U-Net, as well as a training and inference framework focused on dealing with medical images.
Apache License 2.0
1.31k stars, 297 forks

Bug in compute_bbox_loss of retina_net.py #59

Closed: oetzus closed this issue 5 years ago

oetzus commented 5 years ago

Hi Paul, it's me again! In retina_net.py you compute the classification loss using pos_indices = torch.nonzero(anchor_matches > 0) to select the positive samples. For the bbox loss, however, the selection is indices = torch.nonzero(anchor_matches == 1).squeeze(1), so the loss is only computed for anchors matched to the first class. I suppose this is a bug, or did I miss something?
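To make the mismatch concrete, here is a minimal pure-Python sketch (no PyTorch) of the two selections. It assumes, as the thread later establishes, that anchor_matches holds -1 for negative anchors, 0 for neutral anchors, and the (1-based) class id for positive anchors; the example values and the nonzero helper are hypothetical stand-ins, not code from the toolkit.

```python
# Hypothetical anchor_matches for 8 anchors: -1 = negative, 0 = neutral,
# positive values = class id of the matched ground-truth box (assumption).
anchor_matches = [-1, 0, 2, 1, 0, 3, -1, 1]


def nonzero(mask):
    """Pure-Python stand-in for torch.nonzero on a 1D boolean mask."""
    return [i for i, m in enumerate(mask) if m]


# Selection used for the classification loss: every positive anchor.
pos_indices = nonzero([m > 0 for m in anchor_matches])
# Selection used for the bbox loss: only anchors matched to class id 1.
bbox_indices = nonzero([m == 1 for m in anchor_matches])

print(pos_indices)   # [2, 3, 5, 7]
print(bbox_indices)  # [3, 7]
```

With more than one foreground class, the == 1 selection silently drops the anchors matched to classes 2 and 3, so their box regression never receives a gradient.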

pfjaeger commented 5 years ago

Hi, as you can see in line 128, anchor_matches is a flag that only takes the values -1, 0, and 1 for negative, neutral, and positive anchors, respectively. It does not represent the class_id. Since it only takes those values, > 0 and == 1 are the same thing.

oetzus commented 5 years ago

Hi, as far as I understand, it really does represent the class_id, since the classification loss is computed with target_pos, which is the positive subset of anchor_matches (lines 143-144). I also printed the unique values of anchor_matches, and they really go as high as the number of classes.
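Given that anchor_matches carries class ids, one way to fix the bbox loss is to select positives with > 0, exactly as the classification loss does. The sketch below is a pure-Python illustration of that idea, not the toolkit's actual compute_bbox_loss: the function names are hypothetical, and it assumes for simplicity that target_deltas is aligned per anchor with pred_deltas (the real code may store targets differently). The smooth L1 formula matches PyTorch's smooth_l1_loss with beta = 1 and mean reduction.

```python
def smooth_l1(pred, target, beta=1.0):
    """Smooth L1 loss averaged over all elements (mean reduction)."""
    total, n = 0.0, 0
    for row_p, row_t in zip(pred, target):
        for p, t in zip(row_p, row_t):
            d = abs(p - t)
            # Quadratic below beta, linear above it.
            total += 0.5 * d * d / beta if d < beta else d - 0.5 * beta
            n += 1
    return total / n if n else 0.0


def bbox_loss_sketch(target_deltas, pred_deltas, anchor_matches):
    """Hypothetical bbox loss over ALL positive anchors, whatever their class id."""
    # Same selection as the classification loss: anchor_matches > 0.
    indices = [i for i, m in enumerate(anchor_matches) if m > 0]
    if not indices:
        return 0.0  # no positive anchors, so no regression loss
    pred = [pred_deltas[i] for i in indices]
    target = [target_deltas[i] for i in indices]
    return smooth_l1(pred, target)


# Usage: anchors 1 (class 2) and 2 (class 1) are positive; 0 and 3 are ignored.
anchor_matches = [-1, 2, 1, 0]
pred_deltas = [[0.0, 0.0], [0.5, 0.0], [2.0, 0.0], [9.0, 9.0]]
target_deltas = [[0.0, 0.0]] * 4
print(bbox_loss_sketch(target_deltas, pred_deltas, anchor_matches))  # 0.40625
```

Here both positive anchors contribute (0.125 + 1.5 over 4 elements); the original == 1 selection would have used only anchor 2 and returned 0.75.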

pfjaeger commented 5 years ago

Hi, apologies, you are right. I will fix this right away. Thank you so much for catching this, and for insisting after my incorrect first reply!