plun11 opened this issue 12 months ago
I agree with this comment. I also note that using the `.view(-1)` method throws an error for me when my label tensor (`y_true`) has shape `(batch_size, H, W)`. I used a surrogate Dice loss function instead, which avoids the error:

```python
import torch
import torch.nn.functional as F


def dice_loss(y_pred, y_true):
    """
    Compute the Dice loss between predictions and true labels.

    Args:
        y_pred (torch.Tensor): Raw logits with shape (batch_size, num_classes, H, W).
            Softmax is applied inside this function.
        y_true (torch.Tensor): Integer class labels with shape (batch_size, H, W).

    Returns:
        torch.Tensor: Scalar Dice loss.
    """
    # Convert y_true to one-hot format to match the shape of y_pred:
    # (N, H, W) -> (N, H, W, C) -> (N, C, H, W)
    y_true = y_true.long()
    y_true_one_hot = F.one_hot(y_true, num_classes=y_pred.shape[1]).permute(0, 3, 1, 2).float()

    # Apply softmax over the class dimension to get class probabilities
    y_prob = F.softmax(y_pred, dim=1)

    # Per-sample, per-class intersection and denominator (sums over H and W)
    intersection = torch.sum(y_prob * y_true_one_hot, dim=(2, 3))
    union = torch.sum(y_prob, dim=(2, 3)) + torch.sum(y_true_one_hot, dim=(2, 3))

    # Dice score; the small epsilon avoids division by zero
    dice_score = 2.0 * intersection / (union + 1e-6)
    loss = 1 - dice_score  # Dice loss is 1 minus the Dice score

    # Mean Dice loss over all classes and the whole batch
    return torch.mean(loss)
```
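The comment does not quote the error, so the following is only an assumption about one likely cause: flattening logits of shape `(N, C, H, W)` and labels of shape `(N, H, W)` yields vectors of different lengths, so an element-wise product between them fails. This sketch (sizes chosen purely for illustration) shows the mismatch and how one-hot encoding the labels, as the surrogate above does, restores matching shapes.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration
N, C, H, W = 2, 3, 4, 4
y_pred = torch.randn(N, C, H, W)         # logits: (N, C, H, W)
y_true = torch.randint(0, C, (N, H, W))  # integer labels: (N, H, W)

# Flattening both gives vectors of different lengths: N*C*H*W vs N*H*W
flat_pred = y_pred.view(-1)
flat_true = y_true.view(-1)
print(flat_pred.numel(), flat_true.numel())  # 96 32

try:
    _ = flat_pred * flat_true  # lengths 96 and 32 cannot broadcast
except RuntimeError as err:
    print("shape mismatch:", type(err).__name__)

# One-hot encode the labels and move the class axis to dim 1,
# giving (N, C, H, W), the same shape as y_pred
y_true_one_hot = F.one_hot(y_true, num_classes=C).permute(0, 3, 1, 2).float()

# Use reshape rather than view: permute can leave the tensor non-contiguous,
# and view(-1) refuses non-contiguous inputs
flat_one_hot = y_true_one_hot.reshape(-1)
print(y_true_one_hot.shape == y_pred.shape, flat_one_hot.numel())  # True 96
```

Because the one-hot labels and the predictions now share a shape, the surrogate can reduce over `dim=(2, 3)` directly and never needs to flatten mismatched tensors.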
Hello, thank you for the great library.
There is a line in the DiceLoss class:
```python
if not self.from_logits:
    y_pred = F.sigmoid(y_pred)
```
I am not sure, but from the description I think the sigmoid is meant to be applied when `self.from_logits` is `True` (raw logits must be mapped to probabilities before computing Dice), so the condition should be `if` rather than `if not`.
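A minimal sketch of the behaviour the issue proposes, using a hypothetical `BinaryDiceLoss` class rather than the library's actual `DiceLoss`: the sigmoid runs only when `from_logits=True`. If the condition is right, passing logits with `from_logits=True` should give the same loss as passing `sigmoid(logits)` with `from_logits=False`.

```python
import torch
import torch.nn as nn


class BinaryDiceLoss(nn.Module):
    """Hypothetical binary Dice loss illustrating the proposed from_logits check."""

    def __init__(self, from_logits: bool = True, eps: float = 1e-6):
        super().__init__()
        self.from_logits = from_logits
        self.eps = eps

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        # Proposed fix: squash raw logits to [0, 1]; use probabilities as given
        if self.from_logits:
            y_pred = torch.sigmoid(y_pred)
        y_true = y_true.float()
        # Reduce over every dimension except the batch dimension
        dims = tuple(range(1, y_pred.dim()))
        intersection = (y_pred * y_true).sum(dims)
        cardinality = y_pred.sum(dims) + y_true.sum(dims)
        dice = 2.0 * intersection / (cardinality + self.eps)
        return (1.0 - dice).mean()


torch.manual_seed(0)
logits = torch.randn(2, 1, 4, 4)
target = (torch.rand(2, 1, 4, 4) > 0.5).float()

loss_from_logits = BinaryDiceLoss(from_logits=True)(logits, target)
loss_from_probs = BinaryDiceLoss(from_logits=False)(torch.sigmoid(logits), target)
print(torch.allclose(loss_from_logits, loss_from_probs))  # True
```

With the original `if not` check, the same call would apply the sigmoid to inputs that are already probabilities and skip it for raw logits, so the two paths above would disagree.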