BinhuiXie opened this issue 2 years ago (status: Open).
Yes, the extra dimension does represent the "no object" class.
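To illustrate that answer, here is a minimal pure-Python sketch of the target construction. It assumes, as in Deformable DETR-style loss code, that unmatched queries are labeled with index `num_classes`, one past the last real class. The sizes and labels are made up, plain lists stand in for tensors, and `build_targets` is a hypothetical helper.

```python
def build_targets(target_classes, num_classes):
    """Hypothetical helper mirroring the one-hot construction in the loss.

    target_classes: one label per query; unmatched ("no object") queries
    carry the label num_classes, i.e. one past the last real class.
    """
    # One extra column so the "no object" label has somewhere to be scattered.
    onehot = [[0] * (num_classes + 1) for _ in target_classes]
    for row, label in zip(onehot, target_classes):
        row[label] = 1
    # Drop the "no object" column: unmatched queries become all-zero rows,
    # i.e. negatives for every class under a per-class sigmoid loss.
    return [row[:-1] for row in onehot]


num_classes = 3
target_classes = [1, 3, 0, 3]  # queries 1 and 3 (0-based) are matched to no object
targets = build_targets(target_classes, num_classes)
print(targets)  # [[0, 1, 0], [0, 0, 0], [1, 0, 0], [0, 0, 0]]
```

With a sigmoid loss, an all-zero row simply means "push every class score down", so no dedicated background logit is needed.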
Thanks!
Sorry, another question: why is `sigmoid_focal_loss` (binary cross-entropy with logits) usually used in object detection? What are its advantages? Could we use standard cross-entropy with softmax instead?
In my understanding, sigmoid is better suited to multi-label classification: each class gets an independent probability, so when the model is not sure which of two classes an object belongs to, it can predict both, and one of them will be correct.
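To make the difference concrete, this small pure-Python sketch compares the two activations for one query whose (hypothetical) logits are tied between two classes. Sigmoid scores each class independently, so both can be high at once; softmax forces the classes to share a total probability of 1.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def softmax(xs):
    # Subtract the max for numerical stability; the result is unchanged.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for one query over 3 classes: the model cannot decide
# between class 0 and class 1.
logits = [2.0, 2.0, -3.0]

sig_probs = [sigmoid(x) for x in logits]  # independent score per class
soft_probs = softmax(logits)              # scores compete for a total of 1

print([round(p, 3) for p in sig_probs])   # [0.881, 0.881, 0.047]
print([round(p, 3) for p in soft_probs])  # [0.498, 0.498, 0.003]
```

The same independence lets a sigmoid classifier score low on every class at once, which is how an unmatched query can say "no object" without a separate background output.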
That makes sense.
In fact, I tried a `softmax_focal_loss`, modeled on `sigmoid_focal_loss`, as follows:
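```python
import torch
import torch.nn.functional as F


def softmax_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
    """
    Args:
        inputs: A float tensor of shape [batch, num_queries, num_classes].
                The classification logits for each query.
        targets: An integer tensor of shape [batch, num_queries]. Stores the
                 class label for each query in inputs.
        num_boxes: Number of ground-truth boxes, used to normalize the loss.
        alpha: (optional) Weighting factor in range (0, 1). Unlike the sigmoid
               version, here it scales every term equally.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
    Returns:
        Loss tensor
    """
    _EPSILON = 1e-4
    # Softmax over the class dimension, then pick p_t, the target-class probability.
    prob = F.softmax(inputs, dim=-1)
    prob = prob.gather(-1, targets.unsqueeze(-1))
    # Clamp before taking the log to avoid log(0).
    logpt = torch.log(torch.clamp(prob, _EPSILON, 1 - _EPSILON))
    # Down-weight well-classified (easy) queries.
    focal_modulation = (1 - prob) ** gamma
    loss = -alpha * focal_modulation * logpt
    # Average over queries, sum over the batch, normalize by the number of boxes.
    return loss.mean(1).sum() / num_boxes
```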
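For comparison, here is a simplified per-query sketch of both losses in plain Python (no tensors). The sigmoid version is written to follow the `sigmoid_focal_loss` formula used in DETR-style code (an `alpha_t` weight for positives vs. negatives times the modulating factor `(1 - p_t) ** gamma`); it is an illustration, not DINO's implementation, and the softmax version omits the clamp. It makes one difference visible: in the sigmoid loss `alpha` weights positives and negatives differently, whereas in the `softmax_focal_loss` above `alpha` multiplies every term, so it only rescales the loss.

```python
import math


def sigmoid_focal_terms(logits, target_onehot, alpha=0.25, gamma=2.0):
    """Per-class sigmoid focal loss terms for a single query."""
    terms = []
    for x, t in zip(logits, target_onehot):
        p = 1.0 / (1.0 + math.exp(-x))
        # Numerically stable binary cross-entropy with logits.
        ce = max(x, 0.0) - x * t + math.log1p(math.exp(-abs(x)))
        p_t = p * t + (1 - p) * (1 - t)
        # alpha for positives, (1 - alpha) for negatives.
        alpha_t = alpha * t + (1 - alpha) * (1 - t)
        terms.append(alpha_t * ce * (1 - p_t) ** gamma)
    return terms


def softmax_focal_term(logits, target, alpha=0.25, gamma=2.0):
    """Softmax focal loss for a single query (same form as softmax_focal_loss, no clamp)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    p_t = exps[target] / sum(exps)
    # alpha multiplies the whole term, so it acts as a global scale only.
    return -alpha * (1 - p_t) ** gamma * math.log(p_t)


logits = [1.5, -0.5, -2.0]  # hypothetical logits for one query

# Softmax: doubling alpha exactly doubles the loss.
soft_ratio = softmax_focal_term(logits, 0, alpha=0.5) / softmax_focal_term(logits, 0, alpha=0.25)

# Sigmoid, query matched to class 0: alpha weights the positive term,
# (1 - alpha) weights the two negative terms.
terms = sigmoid_focal_terms(logits, [1, 0, 0], alpha=0.25)

# Unmatched ("no object") query: all-zero target, only negative terms.
no_obj = sigmoid_focal_terms(logits, [0, 0, 0], alpha=0.25)

print(round(soft_ratio, 6))  # 2.0
```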
Then I modified the following lines: https://github.com/IDEACVR/DINO/blob/67bbcd97ef30a48cf343b7b0f3ad9ea0795b6fcd/models/dino/dino.py#L383-L385

```python
# Drop the "no object" column, then convert the one-hot targets to class indices with argmax.
target_classes_onehot = target_classes_onehot[:,:,:-1]
loss_ce = softmax_focal_loss(src_logits, target_classes_onehot.argmax(-1), num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1]
losses = {'loss_ce': loss_ce}
```
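One detail that may be worth checking in this modification (an observation, not a confirmed cause of the drop): after `target_classes_onehot[:,:,:-1]`, rows for queries matched to "no object" are all zeros, and `torch.argmax` returns the index of the first maximal value, so those queries all receive class 0 as their softmax target. The plain-Python sketch below uses hypothetical labels and a hypothetical `argmax` helper that mimics that tie-breaking rule.

```python
def argmax(row):
    """Index of the first maximal value (hypothetical stand-in for torch.argmax)."""
    return max(range(len(row)), key=row.__getitem__)


num_classes = 3
target_classes = [1, 3, 0, 3]  # label 3 == num_classes means "no object"

# One-hot targets after the "no object" column has been sliced off.
sliced = [[1 if c == label else 0 for c in range(num_classes)] for label in target_classes]

# What the modified loss passes to softmax_focal_loss.
argmax_targets = [argmax(row) for row in sliced]
print(argmax_targets)  # [1, 0, 0, 0]

# A query matched to class 0 and a "no object" query now share the same target.
indistinguishable = argmax_targets[1] == argmax_targets[2]
```

A softmax loss can only express "none of these classes" if a background class exists. Keeping `target_classes` as-is (so "no object" stays at index `num_classes`) avoids the collision, but only if the classification head outputs `num_classes + 1` logits, as the original DETR does for its softmax cross-entropy. How much the collision matters likely depends on whether index 0 is a real category in your dataset.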
However, performance drops considerably.
Could you give me some advice? Thanks!
Hello, guys. Well done!
I have a quick question about `sigmoid_focal_loss` (https://github.com/IDEACVR/DINO/blob/67bbcd97ef30a48cf343b7b0f3ad9ea0795b6fcd/models/dino/dino.py#L379): why is the third dimension of `target_classes_onehot` one larger than that of `src_logits`? Does the extra dimension represent the "no object" class? Thanks in advance.