mberkay0 / pretrained-backbones-unet

A PyTorch-based Python library with UNet architecture and multiple backbones for Image Semantic Segmentation.
MIT License

DiceLoss question in pretrained-backbones-unet/backbones_unet/model/losses.py #7

Open plun11 opened 7 months ago

plun11 commented 7 months ago

Hello, thank you for the great library.

There is a line in the DiceLoss class:

```python
if not self.from_logits:
    y_pred = F.sigmoid(y_pred)
```

I am not sure, but based on the description, I think the sigmoid is meant to be applied when `self.from_logits` is True. If so, the condition should be `if` rather than `if not`.
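To illustrate what the suggested fix would change, here is a minimal sketch of a binary Dice loss with the condition flipped. It is a hypothetical `BinaryDiceLoss` class, not the library's actual `DiceLoss`. It assumes that `from_logits=True` means the inputs are raw logits, which is how this issue reads the docstring, and it uses `torch.sigmoid` in place of `F.sigmoid`.

```python
import torch
import torch.nn as nn


class BinaryDiceLoss(nn.Module):
    """Hypothetical sketch; NOT the library's DiceLoss implementation."""

    def __init__(self, from_logits: bool = True, smooth: float = 1e-6):
        super().__init__()
        self.from_logits = from_logits  # assumed meaning: inputs are raw logits
        self.smooth = smooth            # avoids division by zero on empty masks

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        # Logits need a sigmoid to become probabilities; probabilities are used as-is.
        if self.from_logits:
            y_pred = torch.sigmoid(y_pred)
        # reshape(-1) also works on non-contiguous tensors, unlike view(-1).
        y_pred = y_pred.reshape(-1)
        y_true = y_true.reshape(-1).float()
        intersection = (y_pred * y_true).sum()
        dice = (2.0 * intersection + self.smooth) / (y_pred.sum() + y_true.sum() + self.smooth)
        return 1.0 - dice


# Usage: confident logits that match the target give a loss close to 0.
logits = torch.tensor([10.0, -10.0])
target = torch.tensor([1.0, 0.0])
good = BinaryDiceLoss(from_logits=True)(logits, target)
# Skipping the sigmoid on logits (the "if not" behaviour) yields a negative loss.
bad = BinaryDiceLoss(from_logits=False)(logits, target)
```

With the sigmoid skipped, the "probabilities" fall outside [0, 1], so the Dice score can exceed 1 and the loss becomes negative, which is one quick way to spot the inverted condition.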

physgorg commented 4 months ago

I agree with this comment. I also found that the `.view(-1)` call throws an error for me when my label tensor (`y_true`) has shape `(batch_size, H, W)`. I used a surrogate Dice loss function instead, which avoids the error:

```python
import torch


def dice_loss(y_pred, y_true):
    """
    Compute the Dice loss between predictions and true labels.

    Args:
        y_pred (torch.Tensor): The predicted class logits with shape (batch_size, num_classes, H, W).
            Softmax is applied inside this function, so do not pass probabilities.
        y_true (torch.Tensor): The true class indices with shape (batch_size, H, W).

    Returns:
        torch.Tensor: Scalar Dice loss.
    """
    # Convert y_true to one-hot format to match the shape of y_pred
    y_true = y_true.long()
    y_true_one_hot = torch.nn.functional.one_hot(y_true, num_classes=y_pred.shape[1]).permute(0, 3, 1, 2).float()

    # Apply softmax to y_pred to get class probabilities
    y_pred = torch.nn.functional.softmax(y_pred, dim=1)

    # Per-sample, per-class intersection and sum of both mask sizes (the Dice "union")
    intersection = torch.sum(y_pred * y_true_one_hot, dim=(2, 3))
    union = torch.sum(y_pred, dim=(2, 3)) + torch.sum(y_true_one_hot, dim=(2, 3))

    # Compute Dice loss
    dice_score = 2.0 * intersection / (union + 1e-6)  # Add a small epsilon to avoid division by zero
    dice_loss = 1 - dice_score  # Dice loss is 1 minus the Dice score

    # Return the mean Dice loss over all classes and batch
    return torch.mean(dice_loss)
```
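The shape problem behind the `.view(-1)` error, and the reason the surrogate one-hot encodes the labels first, can be shown with random tensors. This is a minimal sketch: the shapes are arbitrary, and it assumes (without checking `losses.py`) that the library's loss flattens `y_pred` and `y_true` and multiplies them element-wise.

```python
import torch
import torch.nn.functional as F

# Arbitrary illustrative shapes.
batch_size, num_classes, H, W = 2, 3, 4, 4
y_pred = torch.randn(batch_size, num_classes, H, W)         # logits, (B, C, H, W)
y_true = torch.randint(0, num_classes, (batch_size, H, W))  # class indices, (B, H, W)

# Flattened, the two tensors have different lengths whenever num_classes > 1,
# so an element-wise product between them raises a RuntimeError.
try:
    _ = y_pred.view(-1) * y_true.view(-1)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

# One-hot encoding gives the labels the same (B, C, H, W) layout as y_pred.
# permute() returns a non-contiguous tensor, so reshape(-1) is used instead of view(-1).
y_true_one_hot = F.one_hot(y_true, num_classes=num_classes).permute(0, 3, 1, 2).float()
flat_pred = y_pred.reshape(-1)
flat_true = y_true_one_hot.reshape(-1)
print(mismatch_raised, flat_pred.shape, flat_true.shape)
```

Once the labels are one-hot encoded, both tensors flatten to the same length, so any flatten-and-multiply Dice formulation lines up element by element.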