shruti-jadon / Semantic-Segmentation-Loss-Functions

This repository implements the majority of semantic segmentation loss functions.
MIT License

Focal Tversky Loss implementation #2

Open · YoonSungLee opened this issue 3 years ago

YoonSungLee commented 3 years ago

Hi, thank you for your useful repo with so many loss functions. I had a class imbalance problem in my project and solved it by using this repo.

However, I was not able to use the 'focal_tversky' loss function from this repo. Whenever I use it, the validation loss becomes NaN.

What can I do to use this loss function in my project? Here is my code; I ported the TensorFlow implementation to PyTorch.

# focal tversky loss
import torch
import torch.nn as nn

class FocalTverskyLoss(nn.Module):
    def __init__(self, weight=None, size_average=True, smooth=1, alpha=0.7, gamma=0.75):
        super(FocalTverskyLoss, self).__init__()
        self.smooth = smooth
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        # Focal Tversky loss: (1 - Tversky index) raised to the power gamma
        pt_1 = self.tversky_index(inputs, targets)
        return torch.pow((1 - pt_1), self.gamma)

    def tversky_index(self, inputs, targets):
        # Flatten predictions and targets, then compute TP, FN, and FP terms
        y_true_pos = torch.flatten(targets)
        y_pred_pos = torch.flatten(inputs)
        true_pos = torch.sum(y_true_pos * y_pred_pos)
        false_neg = torch.sum(y_true_pos * (1 - y_pred_pos))
        false_pos = torch.sum((1 - y_true_pos) * y_pred_pos)
        return (true_pos + self.smooth) / (true_pos + self.alpha * false_neg +
                                           (1 - self.alpha) * false_pos + self.smooth)

Thank you.

raimannma commented 1 year ago

The problem is that gamma is 0.75.

If (1 - pt_1) in the forward method is negative, raising it to a non-integer exponent less than 1 is undefined, because you cannot take a fractional root of a negative number (in the real numbers), so the result is NaN.
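
As an illustration (this scenario is an assumption, not stated in the thread): if the model outputs raw logits instead of sigmoid probabilities, the Tversky index can exceed 1, making (1 - pt_1) negative, and the fractional exponent then yields NaN:

import torch

# A negative base raised to a non-integer exponent is undefined over the reals,
# so torch.pow returns NaN -- this is what surfaces as a NaN validation loss.
print(torch.pow(torch.tensor(-0.25), 0.75))  # tensor(nan)
print(torch.pow(torch.tensor(0.25), 0.75))   # tensor(0.3536)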

In the paper they say that γ can range over [1, 3].

I think it is a bug in the code.
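
A possible workaround (a sketch under assumptions, not the repository's official fix: it assumes the inputs are raw logits, so it applies a sigmoid, and it clamps the base of the power at zero) would look roughly like this:

import torch
import torch.nn as nn

class SafeFocalTverskyLoss(nn.Module):
    # Hypothetical variant of the loss above: sigmoid keeps predictions in [0, 1],
    # and clamping prevents a negative base before the fractional power.
    def __init__(self, smooth=1, alpha=0.7, gamma=0.75):
        super().__init__()
        self.smooth = smooth
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        probs = torch.sigmoid(inputs)  # assumes raw logits; map into [0, 1]
        y_true = torch.flatten(targets)
        y_pred = torch.flatten(probs)
        tp = torch.sum(y_true * y_pred)
        fn = torch.sum(y_true * (1 - y_pred))
        fp = torch.sum((1 - y_true) * y_pred)
        tversky = (tp + self.smooth) / (tp + self.alpha * fn +
                                        (1 - self.alpha) * fp + self.smooth)
        # Clamp so numerical overshoot above 1 cannot produce a negative base.
        return torch.pow(torch.clamp(1 - tversky, min=0), self.gamma)

Usage would be something like loss = SafeFocalTverskyLoss()(logits, masks), where logits are the model's raw outputs and masks are binary ground-truth masks. If the predictions passed to the loss are already probabilities, the sigmoid call should be dropped.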