gokulprasadthekkel / pytorch-multi-class-focal-loss

MIT License

Implementation is incorrect #1

Open · ctensmeyer opened this issue 3 years ago

ctensmeyer commented 3 years ago
ce_loss = F.cross_entropy(input, target, reduction=self.reduction, weight=self.weight)
pt = torch.exp(-ce_loss)
focal_loss = ((1 - pt) ** self.gamma * ce_loss).mean()

While ce_loss is correctly -weight[y_t]log(p[y_t]), pt is not p[y_t] as you would expect. Instead, it is exp(weight[y_t]log(p[y_t])) = p[y_t]^(weight[y_t]), which is wrong whenever weight[y_t] != 1.
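The claim about pt can be checked numerically. Per element, exp(-ce_loss) = exp(weight[y_t] * log p[y_t]) = p[y_t] ** weight[y_t]. A minimal sketch, assuming a small random batch and a hypothetical class-weight vector (not taken from the repository):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)               # 4 samples, 3 classes
target = torch.tensor([0, 2, 1, 2])
weight = torch.tensor([0.5, 2.0, 1.0])   # hypothetical class weights

# Per-element weighted CE: -weight[y] * log p[y]
ce = F.cross_entropy(logits, target, weight=weight, reduction="none")
pt_from_ce = torch.exp(-ce)

# Actual softmax probability of the target class
p = F.softmax(logits, dim=1).gather(1, target.unsqueeze(1)).squeeze(1)

print(torch.allclose(pt_from_ce, p ** weight[target]))  # True
print(torch.allclose(pt_from_ce, p))                     # False (weights 0.5 and 2.0 are used)
```

Only the samples whose class weight is exactly 1 recover the true p[y_t].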

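Also, the reduction is performed at the CE step: with reduction='mean' (or 'sum'), ce_loss is already a scalar, so pt is a single number rather than a tensor of target-class probabilities for individual spatial positions, and (1 - pt) ** gamma cannot down-weight individual easy examples.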
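The shape difference is easy to see on a segmentation-style input. A minimal sketch, assuming an (N, C, H, W) logit tensor and (N, H, W) integer targets chosen only for illustration:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 3, 4, 4)          # (N, C, H, W) segmentation-style output
target = torch.randint(0, 3, (2, 4, 4))   # (N, H, W) class indices

ce_mean = F.cross_entropy(logits, target, reduction="mean")
ce_none = F.cross_entropy(logits, target, reduction="none")

print(ce_mean.shape)  # torch.Size([]) -- a single scalar
print(ce_none.shape)  # torch.Size([2, 4, 4]) -- one loss per spatial position

# The focal term needs per-position probabilities, so pt must come from the
# unreduced loss and the reduction must happen after the focal modulation.
pt = torch.exp(-ce_none)
```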

pytholic commented 2 years ago

Hi @ctensmeyer, can you share the correct implementation code?
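One possible fix, as a sketch rather than the repository's actual code: compute cross-entropy without class weights and without reduction, derive pt from that, apply the focal term and the class weights per element, and only then reduce. The class name `FocalLoss`, treating `weight` as the per-class alpha, and averaging over all elements are assumptions; targets are assumed to be valid class indices (no ignore_index handling).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Sketch of multi-class focal loss: -alpha[y] * (1 - p[y])**gamma * log p[y].

    Assumption: `weight` acts as the per-class alpha, and 'mean' averages
    over all elements (samples or spatial positions).
    """

    def __init__(self, weight=None, gamma=2.0, reduction="mean"):
        super().__init__()
        # Registered as a buffer so .to(device) moves it; None is allowed.
        self.register_buffer("weight", weight)
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, target):
        # Unweighted, unreduced CE, so exp(-ce_loss) is exactly p[y] per element.
        ce_loss = F.cross_entropy(logits, target, reduction="none")
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        if self.weight is not None:
            # Class weights are applied only after pt has been computed.
            focal_loss = self.weight[target] * focal_loss
        if self.reduction == "mean":
            return focal_loss.mean()
        if self.reduction == "sum":
            return focal_loss.sum()
        return focal_loss
```

Note that reduction="mean" here divides by the number of elements, whereas `F.cross_entropy(..., weight=w, reduction="mean")` divides by the sum of the selected class weights, so the two conventions differ when weights are given.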

amr-lopezjos commented 2 weeks ago

alphat = self.weight[target]
focal_loss = (alphat * ((1 - pt) ** self.gamma * ce_loss)).mean()
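The snippet above only gives the intended loss if ce_loss (and therefore pt) is computed per element and without class weights. If ce_loss still comes from the original `F.cross_entropy(..., reduction=self.reduction, weight=self.weight)` call, the weights enter twice and pt is still p[y_t] raised to the weight. A sketch comparing the two, assuming a hypothetical 3-class weight vector and gamma = 0 (where the focal loss should reduce to weighted cross-entropy). Sums are compared because the snippet's `.mean()` divides by the element count, while `F.cross_entropy`'s weighted mean divides by the sum of weights:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(6, 3)
target = torch.tensor([0, 1, 2, 2, 1, 0])
weight = torch.tensor([0.5, 2.0, 1.0])  # hypothetical alpha / class weights
gamma = 0.0                              # gamma = 0 should recover weighted CE

alphat = weight[target]

# Intended: unweighted, per-element CE; alpha applied once afterwards.
ce_loss = F.cross_entropy(logits, target, reduction="none")
pt = torch.exp(-ce_loss)
focal_ok = (alphat * ((1 - pt) ** gamma * ce_loss)).sum()

# Pitfall: CE already weighted, so alpha is applied a second time.
ce_w = F.cross_entropy(logits, target, weight=weight, reduction="none")
pt_w = torch.exp(-ce_w)
focal_double = (alphat * ((1 - pt_w) ** gamma * ce_w)).sum()

ref = F.cross_entropy(logits, target, weight=weight, reduction="sum")
print(torch.allclose(focal_ok, ref))  # True
# focal_double equals sum(weight[y]**2 * ce): the weights are effectively squared.
```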