Closed fzqneo closed 2 years ago
Hi @fzqneo
You are right about the background class (n+1) for vanilla SSD. However, the current implementation (RetinaNet) uses focal loss, which applies a binary cross-entropy per class, per anchor box. There is no background class.
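To make the "binary cross-entropy per class, per anchor" idea concrete, here is a minimal sketch of a sigmoid focal loss (not the repo's exact code; `alpha`/`gamma` defaults follow the RetinaNet paper). Background anchors simply get an all-zero target row:

```python
import torch
import torch.nn.functional as F

def sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """Sketch of per-class binary focal loss.

    logits:  [num_anchors, num_classes] raw scores
    targets: [num_anchors, num_classes] one-hot rows (all zeros = background)
    """
    p = torch.sigmoid(logits)
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)        # prob of the true binary label
    loss = ce * (1 - p_t) ** gamma                     # down-weight easy examples
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * loss).sum()
```

Since every class is an independent binary problem, "background" is just the case where all 264 sigmoid outputs should be near zero; no extra class is needed.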
@ahmadki thanks for the quick reply.
Interesting. This is indeed different from canonical SSD and from what I expected. So, for the model's target: if a box is positive (i.e., successfully matched to an annotation by the matcher), it gets a 264-d label vector with a single 1 and all other entries 0; if a box is negative, it gets a 264-d vector of all 0s. Am I right?
Would you mind pointing me to the code which does this preprocessing?
That is correct.
The model loss function is called here. It takes three params:

- `targets` (the ground truth): a list of dicts of length `batch_size`. Each dict holds (among other things):
  - `bboxes`: ground truth boxes (`[num_bbox_per_image, 4]`)
  - `labels`: the class id for each bbox (`[num_bbox_per_image]`). Note the labels are not one-hot encoded yet.
- `head_outputs` (the network output): a dict with:
  - `cls_logits`: classification results (`[batch_size, number_of_anchors, num_of_classes]`)
  - `bbox_regression`: bbox corrections for each anchor (`[batch_size, number_of_anchors, 4]`)
- `anchors`: the prior boxes (`[number_of_anchors, 4]`)

(Of course, for our model you can assume `number_of_anchors=120087` and `number_of_classes=264`.)
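The three params above can be sketched with dummy tensors (toy sizes here; key names follow this thread, and the real dicts carry more keys than shown):

```python
import torch

# Toy sizes for illustration; the real model has
# number_of_anchors=120087 and number_of_classes=264.
batch_size, number_of_anchors, num_of_classes = 2, 10, 264

# targets: one dict per image in the batch
targets = [
    {"bboxes": torch.rand(5, 4), "labels": torch.randint(0, num_of_classes, (5,))},
    {"bboxes": torch.rand(3, 4), "labels": torch.randint(0, num_of_classes, (3,))},
]

# head_outputs: batched network outputs, one row per anchor
head_outputs = {
    "cls_logits": torch.randn(batch_size, number_of_anchors, num_of_classes),
    "bbox_regression": torch.randn(batch_size, number_of_anchors, 4),
}

# anchors: the prior boxes, shared across the batch
anchors = torch.rand(number_of_anchors, 4)
```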
Inside the loss function, the proposal matcher builds `matched_idxs` here: a tensor of shape `[batch_size, number_of_anchors]`. It has negative values for invalid (unmatched) anchors and is `>=0` for valid ones.
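A toy example of how such a per-image `matched_idxs` row is read (values are illustrative, and the specific negative sentinels are an assumption based on torchvision's matcher conventions):

```python
import torch

# One image, 6 anchors:
#   >=0 -> index of the matched ground-truth box
#   <0  -> no valid match (torchvision's matcher uses -1 and -2
#          for below-threshold and between-thresholds anchors)
matched_idxs = torch.tensor([2, -1, 0, -2, 1, -1])

foreground = torch.where(matched_idxs >= 0)[0]  # anchors with a GT match
print(foreground.tolist())  # -> [0, 2, 4]
```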
`matched_idxs` is later used to filter out unwanted bboxes here and here. The bbox classifications (labels) are one-hot encoded here before the actual call to the focal loss function here.
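The one-hot encoding step can be sketched like this (a simplified single-image version under the same toy values as above, not the repo's exact code): foreground anchors get a 1 at their matched box's class, and everything else stays 0.

```python
import torch

num_anchors, num_classes = 6, 264
matched_idxs = torch.tensor([2, -1, 0, -2, 1, -1])  # per-anchor GT index, <0 = no match
gt_labels = torch.tensor([7, 42, 263])              # class ids of the 3 GT boxes

cls_targets = torch.zeros(num_anchors, num_classes)
fg = torch.where(matched_idxs >= 0)[0]
# One-hot encode: a single 1 at the matched box's class;
# background rows remain all-zero (no background class needed).
cls_targets[fg, gt_labels[matched_idxs[fg]]] = 1.0
```

This is exactly the target layout discussed above: positive anchors carry a one-hot 264-d row, negative anchors an all-zero row.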
Thank you for the patient and detailed response! It clears up most of my confusion.
The doc says the MLPerf dataset has 264 "real" classes. Therefore, as far as I understand, the classification head should have 264+1 = 265 logits, with class_id=0 for background. But the PyTorch model outputs only 264-d logits. Am I missing something?