mlyg / unified-focal-loss


softmax or no softmax? #1

Closed: ranabanik closed this issue 2 years ago

ranabanik commented 3 years ago

Hello, I tried to overfit a 3D U-Net on a single data point (multi-class segmentation) with the focal dice loss. When I use only the focal dice loss, without any softmax activation in the network, the dice score (excluding the background class 0) per channel reaches up to ~93%. I am not sure I can call this satisfactory overfitting, as I haven't checked the output yet. [screenshot: results_epoch_11]

When I use a final softmax activation layer, as in this paper, with the same single data point and the same focal dice loss criterion, the results are very different:

      x9 = self.layer11(x8)
      return F.softmax(x9, dim=1) # channel wise softmax

Though the loss and dice of train and validation are on par, the train dice stays below ~40%. [screenshot: results_epoch_11]

import torch

def identify_axis(shape):
    # Helper assumed from my training code: returns the spatial axes of a
    # channels-first tensor, e.g. (B, C, H, W, D) -> [2, 3, 4]
    return list(range(2, len(shape)))

def focal_dice_loss(y_pred, y_true, delta=0.5, gamma_fd=0.75, epsilon=1e-6):
    """
    :param delta: controls weight given to false positives and false negatives;
                  this equates to the Focal Tversky loss when delta = 0.7
    :param gamma_fd: focal parameter, controls degree of down-weighting of easy examples
    """
    axis = identify_axis(y_pred.shape)   # spatial axes to sum over
    ones = torch.ones_like(y_pred)
    p_c = y_pred                         # probability that a voxel belongs to class i
    p_n = ones - y_pred                  # probability that it does not
    g_t = y_true.to(y_pred.dtype)        # one-hot ground truth as float (was .type(torch.cuda.FloatTensor))
    g_n = ones - g_t
    # sum over the spatial axes, then over the batch axis -> one value per class
    tp = torch.sum(torch.sum(p_c * g_t, axis), 0)
    fp = torch.sum(torch.sum(p_c * g_n, axis), 0)
    fn = torch.sum(torch.sum(p_n * g_t, axis), 0)
    # Tversky index per class, clamped for numerical stability -> torch.Size([9])
    tversky_dice = torch.clamp((tp + epsilon) / (tp + delta * fn + (1 - delta) * fp + epsilon),
                               epsilon, 1 - epsilon)
    # focal modulation, dropping class 0 (background)
    focal_dice_loss_fg = torch.pow((1 - tversky_dice), gamma_fd)[1:]
    dice_loss = torch.sum(focal_dice_loss_fg)
    focal_dice_per_class = torch.mean(focal_dice_loss_fg)
    return dice_loss, focal_dice_loss_fg, focal_dice_per_class
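For reference, here is a minimal sketch (random stand-in tensors, not my real data; the 2 x 9 x 48 x 64 x 64 shape matches the data described later in this thread) of how the function above can be exercised:

B, C, H, W, D = 2, 9, 48, 64, 64                           # batch, classes, patch size
logits = torch.randn(B, C, H, W, D)
y_pred = torch.softmax(logits, dim=1)                      # channel-wise probabilities
labels = torch.randint(0, C, (B, H, W, D))                 # integer class map per voxel
y_true = torch.nn.functional.one_hot(labels, C).permute(0, 4, 1, 2, 3)   # -> (B, C, H, W, D)

dice_loss, per_class, mean_fg = focal_dice_loss(y_pred, y_true)
print(dice_loss.item(), per_class.shape, mean_fg.item())   # per_class holds the 8 foreground terms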

Anything wrong with this observation?

mlyg commented 3 years ago

Hi ranabanik! Firstly, thank you for taking interest in this loss function project.

It is difficult for me to give concrete advice because I am more familiar with TensorFlow, and I see you are using PyTorch. However, one thing I am curious about is this line:

"return F.softmax(x9, dim=1) # channel wise softmax"

I wonder if the softmax is being applied to the wrong dimension? Please take a look at this thread: https://github.com/pytorch/pytorch/issues/1020, where they say: "For a 3D input, the softmax takes place over dimension 0." I wonder if the solution is to change the code to:

"return F.softmax(x9, dim=0) # channel wise softmax"

Good luck with your experiments!

P.S. The repository has since been updated to reflect additional results. The 'Focal Dice loss' has been renamed to the 'Focal Tversky loss' to keep naming conventions consistent; the code should be very similar, if not identical, for this loss function.

ranabanik commented 3 years ago

Hello @mlyg, thanks for replying. The tensor I am passing through the softmax layer has shape batches x channels/classes x H x W x D. I think the softmax should be performed channel/class-wise in multi-class segmentation.
If I refer to section 2.1 of this paper:

"The network outputs a softmax quasi-probability map P_s(x) for each segmentation class s ∈ S."

In my case, dim = 1 indexes the channels for the different classes.
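For instance, a quick sanity check (a sketch using my tensor shape, not code from the paper) that dim = 1 produces one probability distribution per voxel across the 9 class channels:

import torch
import torch.nn.functional as F

x = torch.randn(2, 9, 48, 64, 64)                          # (batch, classes, H, W, D)
p = F.softmax(x, dim=1)                                    # normalize over the class channel
print(p.sum(dim=1).allclose(torch.ones(2, 48, 64, 64)))    # True: each voxel sums to 1 over classes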

mlyg commented 3 years ago

That makes sense; dim = 1 is the correct one.

I wonder if the issue could be related to the data you are using? Could you perhaps explain what you mean by 'a single datapoint'? Specifically, I was wondering how the labels are encoded, because that might affect whether it is appropriate to use a softmax function.

ranabanik commented 3 years ago

A single data point is a patch of size 48 x 64 x 64, with labels one-hot encoded, giving a final shape of 2 x 9 x 48 x 64 x 64 corresponding to <batch size, number of channels, H, W, D>. I can share one piece of data for you to check if you have an email address I can send it to.
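As a follow-up to the question about whether a softmax is appropriate: a small check on stand-in data (not my actual labels) that one-hot labels of this shape are mutually exclusive per voxel, which is the setting where a channel-wise softmax makes sense:

import torch
import torch.nn.functional as F

# stand-in for the one-hot labels described above: (batch, classes, H, W, D) = (2, 9, 48, 64, 64)
labels = torch.randint(0, 9, (2, 48, 64, 64))
y_true = F.one_hot(labels, num_classes=9).permute(0, 4, 1, 2, 3)
print(y_true.shape)                                        # torch.Size([2, 9, 48, 64, 64])
print(bool((y_true.sum(dim=1) == 1).all()))                # True: exactly one active class per voxel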