google-research / scenic

Scenic: A Jax Library for Computer Vision Research and Beyond
Apache License 2.0

Focal loss in OWL-ViT #950

Open sargun-nagpal opened 10 months ago

sargun-nagpal commented 10 months ago

Hi, I was going through the loss script for OWL-ViT and wanted to confirm the implementation of the focal loss for training/fine-tuning the model.

From the focal loss paper,

$$FL(p_t) = -\alpha_t (1-p_t)^\gamma \log(p_t)$$

When $y = 1$:

$$FL(p) = -\alpha (1-p)^\gamma \log(p)$$

When $y = 0$:

$$FL(1-p) = -(1-\alpha)\, p^\gamma \log(1-p)$$

Therefore,

$$FL = -\left[y\,\alpha (1-p)^\gamma \log(p) + (1-y)(1-\alpha)\, p^\gamma \log(1-p)\right]$$
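As a quick sanity check of the derivation above, here is a minimal pure-Python sketch (not the OWL-ViT code; the helpers `focal_loss_pt` and `focal_loss_binary` are hypothetical names) that evaluates the paper's $p_t$ form and the expanded $y$-weighted form and confirms they agree for both labels. The defaults $\alpha = 0.25$, $\gamma = 2$ are the values the focal loss paper reports working well, used here only for illustration.

```python
import math

# Hypothetical helpers for illustration; not taken from the scenic source.


def focal_loss_pt(p: float, y: int, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Paper's p_t form: -alpha_t * (1 - p_t)^gamma * log(p_t)."""
    p_t = p if y == 1 else 1.0 - p
    alpha_t = alpha if y == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


def focal_loss_binary(p: float, y: int, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Expanded form: the label y selects the positive or the negative term."""
    pos = -alpha * (1.0 - p) ** gamma * math.log(p)
    neg = -(1.0 - alpha) * p ** gamma * math.log(1.0 - p)
    return y * pos + (1 - y) * neg


# Both forms give the same per-sample loss for either label.
for p in (0.1, 0.5, 0.9):
    for y in (0, 1):
        assert math.isclose(focal_loss_pt(p, y), focal_loss_binary(p, y))
```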

However, in the implementation, I see that the cost is computed as:

$$\text{Cost} = -\alpha (1-p)^\gamma \log(p) + (1-\alpha)\, p^\gamma \log(1-p)$$

This is not the same as the formula above. Could someone please explain why the loss is calculated this way, or tell me if I am misunderstanding something?

hvgazula commented 9 months ago

@sargun-nagpal Did you notice the *= (multiply-and-assign) next to neg_cost_loss as well as pos_cost_loss?

sargun-nagpal commented 9 months ago

@hvgazula Yes, I did. That just calculates the following:

pos_cost_class $= -\alpha(1-p)^\gamma \log(p)$

neg_cost_class $= -(1-\alpha)\, p^\gamma \log(1-p)$

Therefore, pos_cost_class - neg_cost_class $= -\alpha(1-p)^\gamma \log(p) + (1-\alpha)\, p^\gamma \log(1-p)$.
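The sign question can be checked numerically. In this sketch the helpers `pos_term`, `neg_term` and `cost` are hypothetical; they mirror the pos_cost_class and neg_cost_class expressions written above, not the actual scenic source. The difference equals the focal loss the prediction would incur if its label were positive minus the loss it would incur if its label were negative, so it compares the two label choices for the same prediction rather than giving the per-sample loss itself.

```python
import math

ALPHA, GAMMA = 0.25, 2.0  # illustrative values, not read from the scenic config

# Hypothetical helpers for illustration; not taken from the scenic source.


def pos_term(p):
    # Focal loss of a prediction with probability p if its label is y = 1.
    return -ALPHA * (1.0 - p) ** GAMMA * math.log(p)


def neg_term(p):
    # Focal loss of the same prediction if its label is y = 0.
    return -(1.0 - ALPHA) * p ** GAMMA * math.log(1.0 - p)


def cost(p):
    # The difference discussed in this thread: pos_cost_class - neg_cost_class.
    return pos_term(p) - neg_term(p)


for p in (0.05, 0.5, 0.95):
    # Same quantity written as in the issue: -a(1-p)^g log p + (1-a) p^g log(1-p).
    expanded = (-ALPHA * (1 - p) ** GAMMA * math.log(p)
                + (1 - ALPHA) * p ** GAMMA * math.log(1 - p))
    assert math.isclose(cost(p), expanded)

# At these points the cost falls as p rises: a confident prediction is
# cheaper to label positive than to label negative.
assert cost(0.95) < cost(0.5) < cost(0.05)
```

One plausible reading, not confirmed in this thread, is that such a difference serves as a classification cost for comparing candidate assignments (for example, deciding which prediction to match to a ground-truth object) rather than as the training loss.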

This is in contrast to the focal loss formula (mentioned above), where we use the ground-truth label y to select either the pos_cost_loss or the neg_cost_loss term:

$$FL = -\left[y\,\alpha (1-p)^\gamma \log(p) + (1-y)(1-\alpha)\, p^\gamma \log(1-p)\right]$$

hvgazula commented 9 months ago

Hello! Sorry for being unclear earlier. In fact, you derived the answer yourself 😉. All you need to tell yourself is this: in the equation from the article, t is the ground truth, and (in binary classification) it has two possible values, the positive class and the negative class. Now write down the cost for each sample (depending on whether t = pos or t = neg), and you have the equation in your comment.

In other words, imagine you have two samples (one positive [t = pos] and one negative [t = neg]). Write down the cost for the positive sample and for the negative sample; those are the two terms in your derivation.

hvgazula commented 9 months ago

More succinctly: FL(all samples) = FL(pos samples) + FL(neg samples) ...
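The decomposition above can be checked on a toy batch: weighting every sample by y and 1 - y gives the same total as summing the positive term over positive samples and the negative term over negative samples. The helper names and the batch values below are hypothetical and chosen only for illustration. Note that this decomposition uses a plus sign between the two groups.

```python
import math

ALPHA, GAMMA = 0.25, 2.0  # illustrative values

# Hypothetical helpers for illustration; not taken from the scenic source.


def pos_term(p):
    # Focal loss term for a sample whose label is y = 1.
    return -ALPHA * (1.0 - p) ** GAMMA * math.log(p)


def neg_term(p):
    # Focal loss term for a sample whose label is y = 0.
    return -(1.0 - ALPHA) * p ** GAMMA * math.log(1.0 - p)


def total_loss_weighted(probs, labels):
    # Every sample contributes y * pos + (1 - y) * neg.
    return sum(y * pos_term(p) + (1 - y) * neg_term(p)
               for p, y in zip(probs, labels))


def total_loss_split(probs, labels):
    # Positives contribute only pos_term, negatives only neg_term.
    pos = sum(pos_term(p) for p, y in zip(probs, labels) if y == 1)
    neg = sum(neg_term(p) for p, y in zip(probs, labels) if y == 0)
    return pos + neg


probs = [0.9, 0.2, 0.6, 0.05]  # made-up predicted probabilities
labels = [1, 0, 1, 0]          # made-up ground-truth labels
assert math.isclose(total_loss_weighted(probs, labels),
                    total_loss_split(probs, labels))
```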

sargun-nagpal commented 9 months ago

Hi @hvgazula! Thank you for your reply.

I believe it should be: FL(all samples) = y * FL(pos samples) + (1-y) * FL(neg samples)

However, in the code, they use: FL(all samples) = FL(pos samples) - FL(neg samples)

hvgazula commented 9 months ago

"Pos samples" already means y = 1, so multiplying FL(pos samples) by y again is redundant.

As for why FL(all samples) = FL(pos samples) - FL(neg samples): Section 2.1 of the paper referenced at https://github.com/google-research/scenic/blob/1963df79aad2fa2bc5fb2184dd9bbc8761e27e84/scenic/projects/owl_vit/losses.py#L21 should clear up the confusion.