IDEA-Research / DINO

[ICLR 2023] Official implementation of the paper "DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection"
Apache License 2.0

About sigmoid_focal_loss #36

Open BinhuiXie opened 2 years ago

BinhuiXie commented 2 years ago

Hello, guys. Well done!

I have a quick question about sigmoid_focal_loss as follows: https://github.com/IDEACVR/DINO/blob/67bbcd97ef30a48cf343b7b0f3ad9ea0795b6fcd/models/dino/dino.py#L379

Why is the third dimension of target_classes_onehot larger by one than that of src_logits? Does the extra dimension represent the "no object" class?

Thanks in advance.

HaoZhang534 commented 2 years ago

Yes, the extra dimension represents the "no object" class.
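
For context, here is a minimal sketch of how that one-hot target is typically built around the linked line (the shapes and the all-"no object" initialization are assumptions for illustration):

import torch

# Assumed shapes for illustration (e.g. COCO-style, 91 classes).
bs, num_queries, num_classes = 2, 900, 91
src_logits = torch.randn(bs, num_queries, num_classes)
# Matched queries hold their class index; unmatched ones hold num_classes,
# the "no object" index (here every query is unmatched, for simplicity).
target_classes = torch.full((bs, num_queries), num_classes)

# The extra last channel gives scatter_ a valid slot for "no object" ...
target_classes_onehot = torch.zeros(bs, num_queries, num_classes + 1,
                                    dtype=src_logits.dtype, device=src_logits.device)
target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
# ... and slicing it off leaves an all-zero target row for unmatched queries.
target_classes_onehot = target_classes_onehot[:, :, :-1]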

BinhuiXie commented 2 years ago

thanks!

BinhuiXie commented 2 years ago

Sorry, another question.

https://github.com/IDEACVR/DINO/blob/67bbcd97ef30a48cf343b7b0f3ad9ea0795b6fcd/models/dino/dino.py#L384

Why is sigmoid_focal_loss (focal loss on top of per-class binary cross-entropy with logits) usually used in object detection? What are the advantages? Could we use standard cross-entropy with softmax?

HaoZhang534 commented 2 years ago

In my understanding, sigmoid suits detection-style classification better because it treats each class as an independent binary classifier. When the model is not sure which of two classes an object belongs to, it can assign a high score to both, so that one of them is correct; softmax forces the class probabilities to compete and sum to one.
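
For reference, the sigmoid focal loss used in DETR-style codebases looks roughly like this (a sketch following the Deformable-DETR lineage, not copied verbatim from this repo):

import torch.nn.functional as F


def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
    # One independent binary classifier per class: a query can score high
    # on several classes at once, unlike with softmax.
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t is the probability assigned to the correct binary label.
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)
    if alpha >= 0:
        # Balance the rare positive labels against the abundant negatives.
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    return loss.mean(1).sum() / num_boxes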

BinhuiXie commented 2 years ago

That makes sense.

In fact, I tried a softmax_focal_loss modeled on sigmoid_focal_loss, as follows:

import torch
import torch.nn.functional as F


def softmax_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
    """
    Args:
        inputs: A float tensor of shape [batch, num_queries, num_classes].
                The classification logits for each query.
        targets: An integer tensor of shape [batch, num_queries]. Stores the
                 class label for each element in inputs.
        num_boxes: Number of ground-truth boxes, used to normalize the loss.
        alpha: (optional) Weighting factor in range (0, 1) to balance
               positive vs. negative examples.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs. hard examples.
    Returns:
        Loss tensor
    """
    _EPSILON = 1e-4

    # Probability assigned to the target class of each query.
    prob = F.softmax(inputs, dim=2)
    prob = prob.gather(-1, targets.unsqueeze(-1))

    # Clamp before the log to avoid log(0); the modulating factor
    # down-weights easy, well-classified examples.
    logpt = torch.log(torch.clamp(prob, _EPSILON, 1 - _EPSILON))
    focal_modulation = (1 - prob) ** gamma

    loss = -alpha * focal_modulation * logpt

    return loss.mean(1).sum() / num_boxes
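
A quick smoke test with made-up shapes (batch of 2, 900 queries, 91 classes) confirms it reduces to a scalar:

inputs = torch.randn(2, 900, 91)
targets = torch.randint(0, 91, (2, 900))
loss = softmax_focal_loss(inputs, targets, num_boxes=10)
print(loss)  # a single scalar loss value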

And then modified the following lines https://github.com/IDEACVR/DINO/blob/67bbcd97ef30a48cf343b7b0f3ad9ea0795b6fcd/models/dino/dino.py#L383-L385

    # Drop the "no object" channel and recover class indices from the one-hot
    # targets. Note that unmatched queries then have all-zero rows, so argmax
    # maps them to class index 0.
    target_classes_onehot = target_classes_onehot[:,:,:-1]
    loss_ce = softmax_focal_loss(src_logits, target_classes_onehot.argmax(-1), num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1]
    losses = {'loss_ce': loss_ce}

However, the performance drops considerably.

Could you give me some advice? Thanks!