mlcommons / training

Reference implementations of MLPerf™ training benchmarks
https://mlcommons.org/en/groups/training
Apache License 2.0

[SSD] Question about background class #547

Closed: fzqneo closed this issue 2 years ago

fzqneo commented 2 years ago

The doc says the MLPerf dataset has 264 "real" classes. Therefore, as far as I understand, the classification head should have 264+1 = 265 logits, with class_id=0 for background. But the PyTorch model outputs only 264-d logits. Am I missing something?

ahmadki commented 2 years ago

Hi @fzqneo

You are right about the background class (n+1) for vanilla SSD. However, the current implementation (RetinaNet) uses focal loss, which applies binary cross-entropy per class, per anchor box. There is no explicit background class.
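
For intuition, here is a minimal sketch (not the benchmark code itself; the alpha/gamma values and shapes are just assumptions) of per-class sigmoid focal loss, and why no background logit is needed:

```python
import torch
import torch.nn.functional as F

def sigmoid_focal_loss_sketch(cls_logits, cls_targets, alpha=0.25, gamma=2.0):
    """Binary cross entropy per class, per anchor, with focal weighting.

    cls_logits:  [num_anchors, num_classes] raw scores, one independent sigmoid per class
    cls_targets: [num_anchors, num_classes] one-hot rows; an all-zero row = background anchor
    """
    p = torch.sigmoid(cls_logits)
    ce = F.binary_cross_entropy_with_logits(cls_logits, cls_targets, reduction="none")
    p_t = p * cls_targets + (1 - p) * (1 - cls_targets)          # prob of the true label
    alpha_t = alpha * cls_targets + (1 - alpha) * (1 - cls_targets)
    return (alpha_t * (1 - p_t) ** gamma * ce).sum()
```

A background anchor is simply an all-zero target row, so each of the 264 sigmoids is pushed toward 0; that is why the head only needs 264 logits.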

fzqneo commented 2 years ago

@ahmadki thanks for the quick reply.

Interesting. This is indeed different from canonical SSD and from what I expected. So for the model's target: if a box is positive (i.e., successfully matched to an annotation by the matcher), it gets a 264-d label with a single 1 and all other entries 0; if a box is negative, it gets a 264-d vector of all 0s. Am I right?

Would you mind pointing me to the code which does this preprocessing?

ahmadki commented 2 years ago

That is correct.

The model loss function is called here. It takes three params.

  1. targets (or ground truth): a list of dicts of length batch_size; each dict holds (among other things):

    1. bboxes: the ground truth boxes ([num_bbox_per_image, 4])
    2. labels: the class id for each bbox ([num_bbox_per_image]). Note the labels are not one-hot encoded yet.
  2. head_outputs (the network output): a dict with:

    1. cls_logits: classification results ([batch_size, number_of_anchors, num_of_classes])
    2. bbox_regression: bbox corrections for each anchor ([batch_size, number_of_anchors, 4])
  3. anchors: the prior boxes ([number_of_anchors, 4])

(of course for our model you can assume number_of_anchors=120087, number_of_classes=264)
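
To make those shapes concrete, here is a rough sketch of the three arguments with toy sizes (key names follow the description above and the tensor contents are random; the exact names in the repo may differ slightly):

```python
import torch

batch_size, num_anchors, num_classes = 2, 100, 264  # real model: num_anchors=120087

# targets: one dict per image in the batch
targets = [
    {"bboxes": torch.rand(5, 4),                        # 5 ground truth boxes for this image
     "labels": torch.randint(0, num_classes, (5,))},    # class id per box, not one-hot
    {"bboxes": torch.rand(3, 4),
     "labels": torch.randint(0, num_classes, (3,))},
]

# head_outputs: raw predictions for every anchor
head_outputs = {
    "cls_logits": torch.randn(batch_size, num_anchors, num_classes),
    "bbox_regression": torch.randn(batch_size, num_anchors, 4),
}

# anchors: the prior boxes
anchors = torch.rand(num_anchors, 4)
```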

Inside the loss function, the proposal matcher builds matched_idxs here: a tensor of shape [batch_size, number_of_anchors] that holds negative values for anchors that did not match any ground truth box and indices >= 0 for anchors that did.

matched_idxs is later used to filter out unwanted bboxes here and here.

The bbox class labels are one-hot encoded here before the actual call to the focal loss function here.
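
Putting those steps together, here is a simplified sketch of the target-building logic (loosely modeled on the torchvision-style RetinaNet loss; the function and variable names are illustrative, not the repo's):

```python
import torch

def build_cls_targets(matched_idxs_per_image, gt_labels_per_image,
                      num_anchors, num_classes=264):
    """Turn the matcher output into the one-hot targets fed to the focal loss.

    matched_idxs_per_image: [num_anchors], index of the matched GT box, negative if unmatched
    gt_labels_per_image:    [num_gt_boxes], class id of each ground truth box
    """
    gt_classes_target = torch.zeros(num_anchors, num_classes)
    # anchors with matched_idxs >= 0 were matched to a ground truth box
    foreground_idxs = torch.where(matched_idxs_per_image >= 0)[0]
    matched_gt = matched_idxs_per_image[foreground_idxs]
    # put a single 1 at the matched class; background anchors keep an all-zero row
    gt_classes_target[foreground_idxs, gt_labels_per_image[matched_gt]] = 1.0
    return gt_classes_target
```

A negative entry in matched_idxs therefore just leaves that anchor's row at all zeros, which is exactly the all-0s background target described above.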

fzqneo commented 2 years ago

Thank you for the patient and detailed response! It clears up most of my confusion.