facebookresearch / detr

End-to-End Object Detection with Transformers
Apache License 2.0

How to use different losses #259

Open ririya opened 3 years ago

ririya commented 3 years ago

Hello all,

I've been trying to play around with modifying the loss function for solving class imbalance. I tried applying weights to the cross entropy loss and also replacing the cross entropy loss with the focal loss. I'm only doing bbox detection, no segmentation.

Whatever modification I make disrupts performance: the mAP drops by 40%, and some classes don't train at all.

For the weights, I modified the cross-entropy call and replaced self.empty_weight with another tensor, self.class_weights, with len = number of classes. The weights sum to the number of classes, so the overall loss contribution is the same as before. I'm also keeping the last element as self.eos_coef = 0.1.

    loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.class_weights)
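For readers trying to reproduce this setup, here is a minimal sketch of one way such a weight tensor could be built. It is not the code used in this issue: `make_class_weights` is a hypothetical helper, inverse-frequency weighting is an assumed choice, and the sketch assumes the tensor has one entry per object class plus a final no-object entry set to `eos_coef`.

```python
import torch

def make_class_weights(class_counts, eos_coef=0.1):
    """Hypothetical helper: inverse-frequency weights plus a no-object entry.

    class_counts: number of training boxes per object class.
    Returns a tensor of length num_classes + 1.
    """
    counts = torch.as_tensor(class_counts, dtype=torch.float32)
    inv = 1.0 / counts.clamp(min=1.0)          # rarer classes get larger weights
    obj_w = inv * len(counts) / inv.sum()      # object weights sum to num_classes
    # Last entry weights the no-object class (assumed to be the final index).
    return torch.cat([obj_w, torch.tensor([eos_coef])])

class_weights = make_class_weights([100, 10, 1])
```

With very skewed counts, the rarest class can end up with a weight many times larger than the others, which is worth checking before training.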

For the focal loss, I'm using this function instead of the cross-entropy call:

    def focal_loss(self, outputs, targets, gamma: float = 2):
        src_logits = outputs['pred_logits']
        ce_loss = F.cross_entropy(src_logits.transpose(1, 2), targets, weight=self.class_weights, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** gamma * ce_loss).mean()
        return focal_loss

Is there something I'm missing, or are any of my implementations incorrect?
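One detail worth checking in the function above: when `weight` is passed to `F.cross_entropy` with `reduction='none'`, each returned element is already multiplied by its class weight, so `torch.exp(-ce_loss)` yields `p_t ** w` rather than `p_t`. The plain `.mean()` also normalizes differently from the weighted mean that `F.cross_entropy` uses by default when weights are given. Below is a minimal sketch of a variant that keeps the weights out of `p_t`. The function name and the weighted-mean normalization are assumptions, and this is not confirmed to resolve the mAP drop reported here.

```python
import torch
import torch.nn.functional as F

def weighted_focal_loss(src_logits, target_classes, class_weights, gamma=2.0):
    """Softmax focal loss where class weights scale the loss but not p_t.

    src_logits:     [batch, num_queries, num_classes + 1]
    target_classes: [batch, num_queries], dtype int64
    class_weights:  [num_classes + 1]
    """
    # Unweighted per-query CE, so exp(-ce) is the true target probability p_t.
    ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, reduction="none")
    pt = torch.exp(-ce)
    # Look up the weight of each query's target class.
    w = class_weights[target_classes]
    loss = w * (1 - pt) ** gamma * ce
    # Weighted mean, mirroring F.cross_entropy's default reduction with weights.
    return loss.sum() / w.sum()
```

With `gamma=0` this reduces to the weighted cross entropy, which gives a quick sanity check before training with a nonzero `gamma`.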

alcinos commented 3 years ago

Hi @ririya

I don't understand what your self.class_weights looks like; from your description, it sounds like it's exactly how we define it. Could you share the definition?

In general, note that there is a subtle interaction between the loss and the matching cost, so modifying one without modifying the other can lead to unexpected results.
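To make that interaction concrete, here is a minimal sketch, not the repository's actual matcher, of two classification terms a Hungarian matching cost might use. `softmax_class_cost` follows the negative-softmax-probability form that pairs with a cross-entropy loss; `focal_class_cost` is a focal-style form of the kind used by sigmoid-focal DETR variants. The function names and the `alpha`, `gamma`, and `eps` defaults are assumptions for illustration.

```python
import torch

def softmax_class_cost(pred_logits, tgt_ids):
    """Cost = negative softmax probability of each target class.

    pred_logits: [num_queries, num_logits], tgt_ids: [num_targets]
    Returns a [num_queries, num_targets] cost matrix.
    """
    prob = pred_logits.softmax(-1)
    return -prob[:, tgt_ids]

def focal_class_cost(pred_logits, tgt_ids, alpha=0.25, gamma=2.0, eps=1e-8):
    """Focal-style cost on per-class sigmoid probabilities."""
    prob = pred_logits.sigmoid()
    # Cost of treating the class as absent vs. present for each query.
    neg = (1 - alpha) * prob ** gamma * -(1 - prob + eps).log()
    pos = alpha * (1 - prob) ** gamma * -(prob + eps).log()
    return pos[:, tgt_ids] - neg[:, tgt_ids]
```

If the loss is switched to a focal form while the matcher still ranks queries by softmax probability, the assignment and the training signal no longer agree on what a good prediction is, which may be one source of unexpected results.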

Best of luck

ririya commented 3 years ago

Hi @alcinos thanks for replying.

My class_weights is a tensor exactly like the empty_weight tensor; I just changed the values of the weights.

I can see how a class_weights change would affect the matching cost; however, I don't understand why this would be such a big deal in the case of the focal loss. I tried the focal loss without changing the weights, and that simple change is disrupting the training.

TheDarkKnight300 commented 3 years ago

Hi @ririya, did you make any progress with your implementation of the focal loss in DETR?

ririya commented 3 years ago

@TheDarkKnight300 I put that on hold. Let me know if you figure it out.

amirhesamyazdi commented 2 years ago

Hi @alcinos

Thank you so much for the paper.

I have a few related questions.

I have a problem where the background can sometimes look similar to my desired object, and I don't have enough images of the object against all possible backgrounds (different times of day and so on). As a result, I get a relatively large number of false alarms, even though I get great bounding box detections when the object is present. In other words, the model works very well on images containing the object, but on images that don't contain it, you would see false alarms 30 to 40 times out of 100 on average.

I tried increasing eos_coef, which helped a bit but not much.

Do you have any other suggestions for me?

Also, is loss_masks (in detr.py) your implementation of the focal loss? Can I change any hyperparameters to reduce false alarms without sacrificing detections (other than eos_coef, of course)?

Would you suggest feeding in no-object images (we have plenty of those, with the backgrounds that we don't want the model to detect as our object) with a phony 0-by-0-pixel bounding box?

Generally, my question boils down to how the no-object class is handled and whether there is any way to give the model more negative samples explicitly (not necessarily images containing the desired classes).
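On the no-object question, a small sketch may help frame it. It assumes the criterion builds classification targets by first assigning every query the no-object index and then overwriting only the matched queries, which is how loss_labels in detr.py appears to work. `build_target_classes` and the hand-written `indices` are hypothetical stand-ins for the criterion and the matcher output.

```python
import torch

def build_target_classes(num_queries, num_classes, targets, indices):
    """Fill every query with the no-object index, then overwrite matched ones."""
    target_classes = torch.full((len(targets), num_queries), num_classes,
                                dtype=torch.int64)
    for b, (src, tgt) in enumerate(indices):
        # src: matched query indices, tgt: matched ground-truth indices.
        target_classes[b, src] = targets[b]["labels"][tgt]
    return target_classes

# Image 0 has one object of class 2; image 1 is a background-only image.
targets = [
    {"labels": torch.tensor([2]), "boxes": torch.tensor([[0.5, 0.5, 0.2, 0.2]])},
    {"labels": torch.zeros(0, dtype=torch.int64), "boxes": torch.zeros(0, 4)},
]
# Hypothetical matcher output: query 3 is matched in image 0, nothing in image 1.
indices = [
    (torch.tensor([3]), torch.tensor([0])),
    (torch.zeros(0, dtype=torch.int64), torch.zeros(0, dtype=torch.int64)),
]
target_classes = build_target_classes(5, 4, targets, indices)
```

Under that assumption, a background-only image with empty labels and boxes turns all of its queries into no-object targets, so a phony 0-by-0 box would not be needed. Whether a given data pipeline passes such empty targets through unchanged is worth verifying.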