faustomilletari / VNet


How to find 5 labels using Dice Loss Layer #46

Open sajjo79 opened 6 years ago

sajjo79 commented 6 years ago

Hi, I am working on BRATS data, which needs 5 labels, but the Dice loss layer outputs only 2 labels. How can I get 5 labels using the Dice loss layer? Best

gattia commented 6 years ago

You can change the network to output a 4D array that has 6 channels (one for each of the 5 labels, plus one for the background) and apply a softmax across those 6 channels. You can then change the Dice calculation to compute the Dice for each of the 6 labels and sum them.

I forget whether VNet maximizes the Dice directly or takes the negative of the Dice and minimizes that… you'll have to handle the sign appropriately.
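A minimal NumPy sketch of that idea (shapes and names are illustrative, not VNet's actual implementation): softmax across the channel axis, then one soft Dice term per channel, summed and negated so an optimizer can minimize it.

```python
import numpy as np

def softmax(logits, axis=1):
    # Numerically stable softmax across the channel axis.
    e = np.exp(logits - logits.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def soft_dice_per_channel(probs, one_hot_gt, eps=1e-7):
    # probs and one_hot_gt have shape [N, C, D, H, W];
    # returns one soft Dice score per channel (length C).
    axes = (0, 2, 3, 4)          # sum over everything except channels
    intersection = (probs * one_hot_gt).sum(axis=axes)
    denom = probs.sum(axis=axes) + one_hot_gt.sum(axis=axes)
    return (2.0 * intersection + eps) / (denom + eps)

def multiclass_dice_loss(logits, one_hot_gt):
    # Negate the summed Dice so the optimizer can minimize it.
    probs = softmax(logits, axis=1)
    return -soft_dice_per_channel(probs, one_hot_gt).sum()
```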


sajjo79 commented 6 years ago

Hi gattia, thanks for the reply. If I get 6 channels (i.e., 6 images) at the end, then apply argmax across channels to collapse them into one image, and use that image to calculate the Dice loss, will that work?

gattia commented 6 years ago

Argmax isn’t differentiable so it won’t work.
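To see why: argmax is piecewise constant, so small changes to the input almost never change its output, and its gradient is zero almost everywhere. A toy illustration:

```python
import numpy as np

x = np.array([0.30, 0.69])
print(np.argmax(x))                  # -> 1
print(np.argmax(x + [0.0, 1e-3]))    # -> still 1
# The output does not respond to small input perturbations, so
# d(argmax)/dx = 0 almost everywhere and no gradient flows back.
```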


faustomilletari commented 6 years ago

Have a look at the generalized Dice loss in NiftyNet.
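For reference, the generalized Dice loss (Sudre et al., 2017, the formulation NiftyNet implements) weights each label by the inverse of its squared reference volume, so small structures are not drowned out by large ones. A rough NumPy sketch (an approximation, not NiftyNet's actual code):

```python
import numpy as np

def generalized_dice_loss(probs, one_hot_gt, eps=1e-7):
    # probs, one_hot_gt: [N, C, ...]; flatten all voxel axes per class.
    n, c = probs.shape[:2]
    p = probs.reshape(n, c, -1)
    r = one_hot_gt.reshape(n, c, -1)
    # Per-class weight: 1 / (reference volume)^2 down-weights
    # large classes such as the background.
    w = 1.0 / ((r.sum(axis=(0, 2)) ** 2) + eps)
    numer = (w * (p * r).sum(axis=(0, 2))).sum()
    denom = (w * (p + r).sum(axis=(0, 2))).sum()
    return 1.0 - 2.0 * numer / (denom + eps)
```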


sajjo79 commented 6 years ago

Hi, thanks for the response. I have implemented it in Caffe; the code is listed below. However, the predictions it generates are binary. An example output is attached below:

```python
import numpy as np
import caffe

class DiceLossLayer(caffe.Layer):

    def forward(self, bottom, top):
        # bottom[0]: network predictions, bottom[1]: ground-truth labels
        self.diff[...] = bottom[1].data
        top[0].data[...] = 1 - self.dice_coef_multi_class(bottom[0], bottom[1])

    def backward(self, top, propagate_down, bottom):
        if propagate_down[1]:
            raise Exception("label not diff")
        elif propagate_down[0]:
            bottom[0].diff[...] = (-2. * self.diff + self.dice) / self.sum
        else:
            raise Exception("no diff")

    def dice_coef_multi_class(self, y_pred, y_true):
        n_classes = 5
        smooth = np.float32(1e-7)
        y_true = y_true.data
        y_pred = y_pred.data
        # NOTE: this argmax is what makes the output binary per class,
        # and it is not differentiable (see the discussion above).
        y_pred = np.argmax(y_pred, 1)
        y_pred = np.expand_dims(y_pred, 1)

        y_pred = np.ndarray.flatten(y_pred)
        y_true = np.ndarray.flatten(y_true)

        dice = np.zeros(n_classes)
        self.sum = np.zeros([n_classes])
        for i in range(n_classes):
            y_true_i = np.equal(y_true, i)
            y_pred_i = np.equal(y_pred, i)
            self.sum[i] = np.sum(y_true_i) + np.sum(y_pred_i) + smooth
            dice[i] = (2. * np.sum(y_true_i * y_pred_i) + smooth) / self.sum[i]
        self.sum = np.sum(self.sum)
        self.dice = np.sum(dice)
        return self.dice
```

[attached image: example prediction output]

sajjo79 commented 6 years ago

Hi faustomilletari, I tried to explore the NiftyNet code but could not find their implementation of the Dice loss. Can you please point me to the exact file? Best

gattia commented 6 years ago

I'm not responding about NiftyNet, but rather about your other post.

We can't know much about the prediction you printed without a scale attached. I assume black is 0 and that red and blue are entirely different values, but that doesn't matter too much. Here are my suggestions:

For your loss: I don't use Caffe, so I'm not sure about the forward/backward pass code you've written, and maybe you know something I don't about differentiating losses... but as far as I'm aware, you can't use argmax because it is not differentiable. Your loss function must be differentiable to propagate the error back through the network and learn properly.

From your actual loss function, it looks like dimension 1 is the label dimension. If you want to predict all 5 of your labels, you need to:

1. One-hot encode your gold-standard segmentations, if they aren't already. With 5 labels, the length of dimension 1 becomes 6, because you need a slice for each label of interest, plus one for the background.
2. Adjust the network so that its output shape along dimension 1 matches dimension 1 of the one-hot-encoded segmentation from step 1.
3. Use a softmax as the activation function on the network output from step 2. For each pixel, the probabilities of it being each of your potential labels will then sum to 1.0.
4. Use a loss function (Dice) that can handle this one-hot-encoded version of the segmentation.

Here is a link to a repository I just added with some code to do: 1) one-hot encoding, 2) softmax, 3) multi-label Dice. These are rough adaptations of code I had written for a different purpose using Keras/TensorFlow, so they might not work immediately, but they are pretty close. They assume you have 4D stacks of images, with the first dimension indexing the different images and the remaining three dimensions being the three spatial dimensions of the actual MRI or CT or whatever. The code then assumes you turn these 4D images into 5 dimensions, with the second dimension (dimension 1) being the different labels.
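For illustration, here is roughly what the one-hot-encoding helper might look like in NumPy (hypothetical names, not the actual repository code); the soft Dice itself would follow the per-channel pattern sketched earlier in this thread:

```python
import numpy as np

def one_hot_encode(labels, n_classes):
    # labels: integer label map of shape [N, D, H, W] with values
    # 0..n_classes-1. Returns [N, n_classes, D, H, W] of 0s and 1s.
    encoded = np.zeros((labels.shape[0], n_classes) + labels.shape[1:],
                       dtype=np.float32)
    for c in range(n_classes):
        encoded[:, c] = (labels == c)
    return encoded

# Example: 5 foreground labels plus background -> 6 channels.
gt = np.random.randint(0, 6, size=(2, 8, 8, 8))
gt_1hot = one_hot_encode(gt, n_classes=6)
assert gt_1hot.shape == (2, 6, 8, 8, 8)
assert np.all(gt_1hot.sum(axis=1) == 1.0)  # one class per voxel
```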

As @faustomilletari said, there are other published ways to handle multiple labels at once using specific Dice measures, but this is a method I found reasonable, and it worked in some testing that I did. You could extend it by adding a weight for each label; higher weights would correspond to labels that are more important in the segmentation process.

sajjo79 commented 6 years ago

Hi @gattia, I am confused at one point: the output of the network after softmax will be [N,C,H,W], where C is the number of classes we want to predict. Each [N,_,H,W] slice will contain floating-point (double) values that will not equal the values given in the ground truth. Suppose the ground-truth classes are labeled 1, 2, 3. How can I compute the intersection of a tensor of floating-point values with a tensor of integer values?

gattia commented 6 years ago

I'm not sure whether you are confused about the one-hot encoding itself, or about the fact that the Dice is calculated using integer values from the ground-truth data and some sort of decimal from the network's prediction.

If it's the one-hot encoding: the first thing you should do is one-hot encode the ground truth. After you run the one-hot-encoding function I included, each channel should contain only 0s and 1s, and the channel number should coincide with the label of interest. So, instead of having 2s wherever there is a tumour (I'm making this example up), you will have 1s on the second channel. Everywhere there isn't a tumour should be 0s (on that same second channel).

If it's the fact that the network outputs a decimal (between 0 and 1) while the ground truth is either 0 or 1: this is how it learns. As the decimal gets closer and closer to the correct value (0 or 1, depending on the case), the Dice score gets better and better. So, when you are done creating a segmentation, you will have to decide at what threshold a pixel should be considered part of a particular tissue of interest. Typically, if the prediction is >0.5 I assign it to that tissue, but I'm sure others use different cutoffs for different reasons.
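Note that differentiability only matters during training; at inference time no gradients are needed, so one common choice (my assumption here, not something prescribed in this thread) is simply to take the channel-wise argmax instead of thresholding each channel separately:

```python
import numpy as np

def probs_to_labels(probs):
    # probs: softmax output [N, C, D, H, W]. Returns an integer
    # label map [N, D, H, W]; fine at inference time, where no
    # gradients are required.
    return np.argmax(probs, axis=1)
```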

If it's neither of these, you'll have to re-specify the problem, because I don't understand.