ellisdg / 3DUnetCNN

PyTorch 3D U-Net Convolution Neural Network (CNN) designed for medical image segmentation
MIT License

About Evaluation metrics for BraTS example #316

Open tarekegn82 opened 1 year ago

tarekegn82 commented 1 year ago

Dear Ellis,

Thank you for your earlier responses. I have managed to run fivefold cross-validation training on the BraTS dataset. As you know, Dice is used here as the loss function. Similarly, can we compute the Dice coefficient and the Hausdorff distance (HD) for each tumor class (WT, ET, and TC)?

Best.

ellisdg commented 1 year ago

Update 2/9/24 - See this gist for a working example on how to use MONAI to evaluate predictions: https://gist.github.com/ellisdg/df06194d902e0fe70ba8d9ecbb1d0cb5

Hi @tarekegn82, a simple way to compute Dice and HD is to use MONAI. I think something along the lines of the following should work:

```python
from monai.transforms import LoadImage
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from unet3d.utils import compile_one_hot_encoding
import torch

loader = LoadImage(image_only=True, ensure_channel_first=True)

pred_filename = "/path/to/filename.nii.gz"
label_filename = "/path/to/label_filename.nii.gz"

pred = loader(pred_filename)[None]  # "None" adds a batch dimension, making the image 5D
labelmap = loader(label_filename)[None]  # "None" adds a batch dimension, making the image 5D

# convert the label map into a one-hot image (WT, TC, ET)
onehot = compile_one_hot_encoding(labelmap, n_labels=3, labels=[[2, 1, 4], [1, 4], 4])
# note that the labels are organized in a hierarchy! WT equals 2, 1 and 4, TC equals 1 and 4, and ET is just 4.
# This is important as otherwise you will not get the correct masks.
# Also, the labels have changed for the more recent BraTS challenges, so make sure you are using the correct labels
# for your dataset.

# threshold the prediction
pred_labels = pred >= 0.5
```

~~`for channel in range(3):`~~
~~`    _pred = pred_labels[:, channel][:, None]`~~
~~`    _truth = onehot[:, channel][:, None]`~~
~~`    dice = DiceMetric()(_pred, _truth)`~~
~~`    hd = HausdorffDistanceMetric()(_pred, _truth)`~~
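The snippet above depends on MONAI and the `unet3d` package. As a dependency-light cross-check, here is a minimal NumPy sketch of the same per-region evaluation. It assumes the label hierarchy described in the comments (WT = 1, 2, 4; TC = 1, 4; ET = 4) and a prediction that has already been thresholded into three binary channels in the order (WT, TC, ET). `hierarchical_one_hot`, `dice` and `hausdorff` are hypothetical helper names. The Hausdorff distance is computed by brute force over all foreground voxel coordinates, in voxel units, so it is not guaranteed to match MONAI's `HausdorffDistanceMetric`.

```python
import numpy as np

# Assumed hierarchical regions: each channel is the union of several label values,
# so the masks overlap (ET lies inside TC, which lies inside WT).
REGIONS = {"WT": [1, 2, 4], "TC": [1, 4], "ET": [4]}


def hierarchical_one_hot(labelmap, regions=REGIONS):
    """Stack one binary mask per region, shape (n_regions, *labelmap.shape)."""
    return np.stack([np.isin(labelmap, values) for values in regions.values()])


def dice(pred, truth):
    """Dice coefficient of two binary masks; defined as 1.0 when both are empty."""
    pred, truth = pred.astype(bool), truth.astype(bool)
    denom = pred.sum() + truth.sum()
    if denom == 0:
        return 1.0
    return float(2.0 * np.logical_and(pred, truth).sum() / denom)


def hausdorff(pred, truth):
    """Symmetric Hausdorff distance (voxel units) between foreground voxel sets, brute force."""
    a, b = np.argwhere(pred), np.argwhere(truth)
    if len(a) == 0 or len(b) == 0:
        return float("nan")  # undefined when either mask is empty
    # Pairwise Euclidean distances between every voxel of a and every voxel of b.
    d = np.sqrt(((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1))
    # Largest nearest-neighbour distance in either direction.
    return float(max(d.min(axis=1).max(), d.min(axis=0).max()))
```

The pairwise distance matrix grows with the product of the two voxel counts, so this is only practical for small test volumes; for full BraTS scans a library implementation such as MONAI's is the better choice.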

My plan is to write a function that does this automatically when cross-validation is run, but I haven't found the time yet. If you or someone else is interested in writing the code to run it during cross-validation, that would be awesome!
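A per-fold evaluation of the kind described above could look roughly like the following sketch. `evaluate_fold` is a hypothetical helper, not part of `unet3d`. It assumes each case is a pair of NumPy arrays: per-channel probabilities of shape `(3, *spatial)` in the order (WT, TC, ET) and the ground-truth label map, using the same hierarchical label values as the snippet above. Only Dice is reported to keep it short.

```python
import numpy as np

# Assumed hierarchical label values, matching the (WT, TC, ET) channel order.
REGIONS = {"WT": [1, 2, 4], "TC": [1, 4], "ET": [4]}


def evaluate_fold(cases, threshold=0.5):
    """Return {region: (mean Dice, std Dice)} over all cases in one fold.

    cases: iterable of (probs, labelmap) pairs; probs has shape (3, *spatial),
    labelmap has shape spatial.
    """
    scores = {name: [] for name in REGIONS}
    for probs, labelmap in cases:
        for channel, (name, values) in enumerate(REGIONS.items()):
            pred = probs[channel] >= threshold  # threshold each channel independently
            truth = np.isin(labelmap, values)  # hierarchical ground-truth mask
            denom = pred.sum() + truth.sum()
            # An empty prediction on an empty ground truth counts as a perfect score.
            scores[name].append(1.0 if denom == 0 else float(2.0 * (pred & truth).sum() / denom))
    return {name: (float(np.mean(s)), float(np.std(s))) for name, s in scores.items()}
```

How empty masks are scored (here, 1.0 when both are empty) noticeably affects ET averages, since some cases have no enhancing tumor; that convention is a choice, not a standard fixed by this thread.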

Updated 1/16/2024 - I rewrote the code because using argmax was the wrong suggestion for hierarchical labels. Note that if you are not using hierarchical labels, the code will look different.
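To see why argmax is the wrong decoding for hierarchical labels: a voxel inside the enhancing tumor belongs to WT, TC and ET at once, so all three channels should be on. Argmax, by construction, keeps exactly one channel per voxel. A minimal NumPy illustration, assuming sigmoid-style per-channel probabilities in the order (WT, TC, ET); the probability values are made up for illustration.

```python
import numpy as np

# Made-up per-channel probabilities for three voxels, channel order (WT, TC, ET).
# Voxel 0: edema only; voxel 1: tumor core; voxel 2: enhancing tumor.
probs = np.array([
    [0.90, 0.95, 0.97],  # WT
    [0.20, 0.80, 0.90],  # TC
    [0.10, 0.30, 0.85],  # ET
])

# Thresholding each channel independently keeps the nesting (ET inside TC inside WT).
thresholded = probs >= 0.5

# Argmax picks a single winning channel per voxel, so the regions become mutually exclusive.
winner = probs.argmax(axis=0)
argmax_masks = np.stack([winner == c for c in range(probs.shape[0])])

print(thresholded.astype(int))
print(argmax_masks.astype(int))
```

With these numbers WT has the highest probability in every tumor voxel, so argmax assigns everything to WT and the TC and ET masks come out empty, while thresholding recovers the nested regions.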

tarekegn82 commented 1 year ago

Thank you for the tip. I will try it this way and see if it works. I will also try to add the code to the cross-validation pipeline. In the meantime, please let me know if you write the code for the latter.

ellisdg commented 8 months ago

I've edited the previous comment to correct my suggestion to use argmax, which was an error on my part.

I added some comments to clear things up about converting predictions with a hierarchy into a label map. Here is the snippet of code with comments: https://github.com/ellisdg/3DUnetCNN/blob/master/unet3d/utils/one_hot.py#L101
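A minimal sketch of that kind of conversion, assuming channel order (WT, TC, ET) and the label values from the earlier snippet (2, 1 and 4). It paints the regions from the outermost inward, so each nested region overwrites its parent: WT-only voxels end up as 2, TC-but-not-ET voxels as 1, and ET voxels as 4. This is an illustration, not a copy of the linked `one_hot.py` code, and `masks_to_labelmap` is a hypothetical name.

```python
import numpy as np


def masks_to_labelmap(masks, labels=(2, 1, 4), threshold=0.5):
    """Convert nested region probabilities (WT, TC, ET) into a single label map.

    masks: array of shape (3, *spatial); labels: value painted for each channel.
    """
    labelmap = np.zeros(masks.shape[1:], dtype=np.uint8)
    # Paint from the outermost region inward: WT -> 2, then TC -> 1, then ET -> 4.
    for channel, label in enumerate(labels):
        labelmap[masks[channel] >= threshold] = label
    return labelmap


# Made-up probabilities for four voxels: edema, tumor core, enhancing tumor, background.
probs = np.array([
    [0.90, 0.95, 0.97, 0.10],  # WT
    [0.20, 0.80, 0.90, 0.05],  # TC
    [0.10, 0.30, 0.85, 0.02],  # ET
])
labelmap = masks_to_labelmap(probs)
```

Because later channels overwrite earlier ones, a voxel predicted as ET is labeled 4 even if its WT channel is below threshold; whether to enforce the nesting first is a design choice this sketch leaves open.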