fundamentalvision / Deformable-DETR

Deformable DETR: Deformable Transformers for End-to-End Object Detection.
Apache License 2.0

Questions about computing focal loss #172

Open the-yanqi opened 1 year ago

the-yanqi commented 1 year ago

https://github.com/fundamentalvision/Deformable-DETR/blob/11169a60c33333af00a4849f1808023eba96a931/models/segmentation.py#L221

The focal loss is computed over all object queries (e.g., with num_queries=300, all 300 queries contribute to the focal loss). However, the loss is only divided by num_boxes, the total number of ground-truth boxes in the batch, which is significantly smaller than the number of object queries.

Do you have any specific reasons for computing the focal loss in this way?

Thanks
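For reference, here is a NumPy paraphrase of the normalization being asked about — a sketch, not the repo's actual code (the original uses PyTorch's `F.binary_cross_entropy_with_logits`; shapes and defaults below follow the linked `sigmoid_focal_loss`):

```python
import numpy as np

def sigmoid_focal_loss(inputs, targets, num_boxes, alpha=0.25, gamma=2.0):
    """NumPy sketch of the linked sigmoid_focal_loss.
    inputs/targets: (batch, num_queries, num_classes)."""
    prob = 1.0 / (1.0 + np.exp(-inputs))  # sigmoid
    # numerically stable elementwise binary cross-entropy with logits
    ce_loss = np.maximum(inputs, 0) - inputs * targets + np.log1p(np.exp(-np.abs(inputs)))
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * (1 - p_t) ** gamma
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    # mean over the query dimension, then divide by num_boxes only
    return loss.mean(axis=1).sum() / num_boxes
```

Note the final line: every query contributes to `loss`, but the normalizer is `num_boxes`, not the number of queries.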

EricWiener commented 1 year ago

The focal loss implementation comes from the original DETR paper where it was applied to masks for the panoptic segmentation extension of DETR. From the DETR repo:

All losses are normalized by the number of objects inside the batch. Source.

The label loss in Deformable DETR is calculated with:

```python
loss_ce = (
    sigmoid_focal_loss(source_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2)
    * source_logits.shape[1]
)
```

sigmoid_focal_loss first computes loss.mean(1).sum(), which averages the per-element loss over the query dimension and then sums over the batch and class dimensions. It then divides by num_boxes to normalize by the number of GT objects. Finally, the value returned by sigmoid_focal_loss is multiplied by source_logits.shape[1] (the number of queries), which counteracts the earlier division by the number of queries performed by loss.mean(1). The net effect is that the class loss is normalized only by the number of GT boxes, not by the number of queries.
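The cancellation can be checked numerically. A sketch with hypothetical shapes (2 images, 300 queries, 91 classes, 7 GT boxes — all made-up numbers), using a random array as a stand-in for the per-element focal loss:

```python
import numpy as np

# hypothetical shapes, COCO-style
bs, num_queries, num_classes = 2, 300, 91
num_boxes = 7  # total GT boxes in the batch, far fewer than queries

rng = np.random.default_rng(0)
loss = rng.random((bs, num_queries, num_classes))  # stand-in per-element focal loss

# what sigmoid_focal_loss returns: mean over queries, sum, divide by num_boxes
inner = loss.mean(axis=1).sum() / num_boxes
# the caller then multiplies by the number of queries (source_logits.shape[1])
loss_ce = inner * num_queries

# net effect: a plain sum over every query/class, normalized only by num_boxes
assert np.isclose(loss_ce, loss.sum() / num_boxes)
```

So the multiplication by source_logits.shape[1] exactly undoes the mean(1), leaving num_boxes as the sole normalizer.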