fepegar / torchio

Medical imaging toolkit for deep learning
https://torchio.org
Apache License 2.0

How to handle multi-class segmentation? #237

Closed: meghbhalerao closed this issue 4 years ago

meghbhalerao commented 4 years ago

Hi everyone, how does TorchIO handle multi-class segmentation? For example, if I am training a 3D semantic segmentation model with multiple classes, my ground truth label would be a 4D one-hot encoded label of shape num_classes * H * W * D, as in this snippet:

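```python
import torchio
from torchio import Subject, Image

# torchio.INTENSITY and torchio.LABEL are the image-type constants
# of the TorchIO version used in this issue
one_subject = Subject(
    T1=Image('../BRATS2018_crop_renamed/LGG75_T1.nii.gz', torchio.INTENSITY),
    T2=Image('../BRATS2018_crop_renamed/LGG75_T2.nii.gz', torchio.INTENSITY),
    label=Image('../BRATS2018_crop_renamed/LGG75_Label.nii.gz', torchio.LABEL),
)
```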

What kind of mask does the label expect? Should it be a one-hot encoded mask, or can it be a 3D mask, with TorchIO handling the one-hot encoding internally? Thank you.
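For reference, here is a minimal NumPy sketch of the two representations the question contrasts. It assumes the class indices are the contiguous integers 0 to num_classes - 1, and the function names are hypothetical, not TorchIO API. A 3D label map of shape (H, W, D) can be expanded into a one-hot array of shape (num_classes, H, W, D), and taking the argmax over the class axis recovers the label map.

```python
import numpy as np


def label_map_to_one_hot(label_map: np.ndarray, num_classes: int) -> np.ndarray:
    """Convert an (H, W, D) integer label map to a (num_classes, H, W, D) one-hot array."""
    # Indexing the identity matrix with the labels gives shape (H, W, D, num_classes)
    one_hot = np.eye(num_classes, dtype=np.uint8)[label_map]
    # Move the class axis to the front (channel-first layout)
    return np.moveaxis(one_hot, -1, 0)


def one_hot_to_label_map(one_hot: np.ndarray) -> np.ndarray:
    """Recover the (H, W, D) label map by taking the argmax over the class axis."""
    return np.argmax(one_hot, axis=0)


# Toy 2 x 2 x 2 label map with three classes (0 = background)
label_map = np.array([[[0, 1], [2, 0]], [[1, 1], [0, 2]]])
one_hot = label_map_to_one_hot(label_map, num_classes=3)
print(one_hot.shape)  # (3, 2, 2, 2)
```

The conversion is lossless in both directions only because every voxel has exactly one class. Overlapping or fractional memberships need a different representation.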

romainVala commented 4 years ago

Hello, this is a missing feature, but it should be addressed soon (given the recent needs). See related issue #213.

It does not involve the same data, but it has the same need: 4D support #234.

Maybe @GFabien can help, if @fepegar has not already solved it...

meghbhalerao commented 4 years ago

Hi @romainVala, thank you for your response. Yes, we need support for multi-class segmentation. Should we add this as a feature request? Thanks, Megh

meghbhalerao commented 4 years ago

Also, if I am not wrong, right now there is only support for binary masks, right?

fepegar commented 4 years ago

Hi @meghbhalerao,

As @romainVala said, we don't support 4D images yet. Can a single pixel be assigned to more than one class? If so, you can add each class as an individual image for now. If not, you can create a 3D label map and add it to the Subject. Would either of those work for you?
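The choice between the two options can be sketched with a hypothetical NumPy helper (not part of TorchIO). It merges per-class binary masks into a single 3D label map but refuses to do so when any voxel is claimed by more than one class. That overlapping case is the one where the masks should instead stay as separate images.

```python
import numpy as np
from typing import Sequence


def masks_to_label_map(masks: Sequence[np.ndarray]) -> np.ndarray:
    """Merge non-overlapping binary masks into one label map.

    Mask i is written as label i + 1; voxels in no mask stay 0 (background).
    Raises ValueError if any voxel belongs to more than one mask.
    """
    stacked = np.stack([m.astype(bool) for m in masks])  # (num_masks, H, W, D)
    if (stacked.sum(axis=0) > 1).any():
        raise ValueError("Masks overlap; keep them as separate images instead.")
    label_map = np.zeros(stacked.shape[1:], dtype=np.int64)
    for index, mask in enumerate(stacked):
        label_map[mask] = index + 1
    return label_map


# Two toy, non-overlapping masks (names are illustrative only)
tumor = np.zeros((2, 2, 2), dtype=bool)
tumor[0, 0, 0] = True
edema = np.zeros((2, 2, 2), dtype=bool)
edema[1, 1, 1] = True
label_map = masks_to_label_map([tumor, edema])
```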

sarthakpati commented 4 years ago

Hi @fepegar,

I believe the question is about the ability to predict multiple classes. Take the BraTS dataset, for example, where each segmentation file contains the labels 0, 1, 2 and 4. Right, @meghbhalerao?
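Because the BraTS labels 0, 1, 2 and 4 are not contiguous, a common preprocessing step is to remap them to the indices 0 to 3 before one-hot encoding or computing a loss. This is a sketch of a user-side step, not something this thread says TorchIO does automatically. A lookup table performs the remapping in one vectorized step, and `remap_labels` is a hypothetical helper name.

```python
import numpy as np

# Original BraTS label values; each maps to its position in this array
BRATS_LABELS = np.array([0, 1, 2, 4])


def remap_labels(label_map: np.ndarray, labels: np.ndarray = BRATS_LABELS) -> np.ndarray:
    """Map each value in `labels` to its index in that array (0, 1, 2, ...)."""
    lookup = np.full(labels.max() + 1, -1, dtype=np.int64)  # -1 marks values missing from `labels`
    lookup[labels] = np.arange(len(labels))
    remapped = lookup[label_map]
    if (remapped < 0).any():
        raise ValueError("label_map contains values not listed in `labels`")
    return remapped


remapped = remap_labels(np.array([[[0, 4], [2, 1]]]))
```

Indexing `BRATS_LABELS` with the remapped array restores the original values, which is useful when writing predictions back in the BraTS convention.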

meghbhalerao commented 4 years ago

Hi @fepegar, yes, I am doing multi-class segmentation. Also, I am not sure what exactly a 3D label map means. Does it mean a 3D image where each pixel in the mask is assigned to a particular class (not one-hot encoded, but using class indices such as the 0, 1, 2, 4 that @sarthakpati mentioned)?

GFabien commented 4 years ago

Hi @meghbhalerao, yes, a 3D label map is exactly that. In my experience with TorchIO so far, handling a 3D label map works very well; TorchIO expects this kind of representation. It becomes trickier when your data cannot be represented as a 3D label map, for example when you allow voxels to belong to several classes (fuzzy clustering) to account for partial volumes (PV).

The current solution is to create an image in the subject for each PV label map. However, it becomes impossible to deal with them jointly in the transforms without implementing your own transforms or modifying the current implementation of torchio.

GFabien commented 4 years ago

Sorry, I was wrong in my last message, transforms are ok.

Problems arise when you use patch-based approaches and want to define sampling strategies based on the values of your PV label maps. We wanted every class to have the same probability of being drawn in a patch. To do so, take your PV label maps, compute the total volume occupied by each class (the sum over its PV label map), and sum all the maps weighted by the inverse of those volumes. The current way to do this is to store the resulting probability map under its own key in the sample and pass that key to the WeightedSampler.
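The weighting described above can be written as p(x) = Σ_c m_c(x) / V_c, where m_c is the PV map of class c and V_c = Σ_x m_c(x) is its total volume. Each term then sums to 1 over the image, so every class contributes the same total probability mass regardless of its size. Below is a minimal NumPy sketch, assuming every class has a non-zero volume. The final normalization to a sum of 1 is an extra choice made here, not something the comment prescribes.

```python
import numpy as np
from typing import Sequence


def balanced_probability_map(pv_maps: Sequence[np.ndarray]) -> np.ndarray:
    """Sum PV label maps weighted by the inverse of each class's total volume."""
    probability = np.zeros(pv_maps[0].shape, dtype=np.float64)
    for pv_map in pv_maps:
        volume = pv_map.sum()  # total (partial) volume occupied by this class
        if volume == 0:
            raise ValueError("A class with zero volume cannot be weighted by 1 / V")
        probability += pv_map / volume  # this class now contributes a total mass of 1
    return probability / probability.sum()  # normalize so the map sums to 1


# A large class (32 voxels) and a tiny class (1 voxel) in a 4 x 4 x 4 image
big = np.zeros((4, 4, 4))
big[:2] = 1.0
small = np.zeros((4, 4, 4))
small[3, 3, 3] = 1.0
p = balanced_probability_map([big, small])
```

Here the single voxel of the small class receives half of the total probability, and the 32 voxels of the large class share the other half. The resulting array is what would be stored under its own key in the sample for the WeightedSampler to read.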

Maybe an implementation similar to the masking_method argument of some scaling transforms would be a good idea, but it would need to accept the whole sample rather than only the Image. This would provide the flexibility to deal with multiple PV label maps and with the possibly complicated ways of building the probability map from (PV) label map(s).
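The proposal can be pictured with a purely hypothetical interface. Neither `attach_probability_map`, `union_of_pv_maps` nor the `pv_` key prefix is TorchIO API. The user supplies a function that receives the whole sample, represented here as a plain dict of arrays, and returns the sampling map. That map is then stored under its own key for a sampler to read.

```python
import numpy as np
from typing import Callable, Dict

Sample = Dict[str, np.ndarray]


def attach_probability_map(
    sample: Sample,
    probability_fn: Callable[[Sample], np.ndarray],
    key: str = "sampling_map",
) -> Sample:
    """Return a copy of `sample` with probability_fn(sample) stored under `key`."""
    result = dict(sample)  # shallow copy so the input sample is not modified
    result[key] = probability_fn(sample)
    return result


def union_of_pv_maps(sample: Sample) -> np.ndarray:
    """Example callable: sample uniformly wherever any PV map is non-zero."""
    pv_keys = [k for k in sample if k.startswith("pv_")]  # hypothetical naming convention
    union = np.zeros(sample[pv_keys[0]].shape, dtype=bool)
    for k in pv_keys:
        union |= sample[k] > 0
    return union.astype(np.float64)


gm = np.zeros((2, 2, 2))
gm[0, 0, 0] = 0.7
wm = np.zeros((2, 2, 2))
wm[0, 0, 0] = 0.3
wm[1, 1, 1] = 1.0
sample = {"t1": np.ones((2, 2, 2)), "pv_gm": gm, "pv_wm": wm}
out = attach_probability_map(sample, union_of_pv_maps)
```

Passing the whole sample rather than a single Image is what allows the callable to combine several PV maps, which is the flexibility the comment asks for.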

fepegar commented 4 years ago

You can use torchio.LabelMap for your segmentations. They can be one-hot or not, as everything is 4D after #238.
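This small sketch shows what "everything is 4D" means for label data, assuming a channel-first (C, W, H, D) layout. A plain label map gets a singleton channel axis, while a one-hot or PV stack already has one channel per class. The helper name `as_4d` is hypothetical.

```python
import numpy as np


def as_4d(label: np.ndarray) -> np.ndarray:
    """Return a channel-first 4D array: (1, W, H, D) for a 3D label map, unchanged if already 4D."""
    if label.ndim == 3:
        return label[np.newaxis]  # add a singleton channel axis
    if label.ndim == 4:
        return label  # already (C, W, H, D), e.g. one-hot or PV maps
    raise ValueError(f"Expected a 3D or 4D array, got {label.ndim}D")


label_map_4d = as_4d(np.zeros((5, 6, 7), dtype=np.int64))
one_hot_4d = as_4d(np.zeros((4, 5, 6, 7), dtype=np.uint8))
```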