MIC-DKFZ / nnUNet

Guide to using nnU-Net with overlapping labels #653

Closed: Mufid99 closed this issue 1 year ago

Mufid99 commented 3 years ago

Hello, I am new to nnU-Net and I wanted to say that it is great! So far, I have been able to use it to segment a single label on my own 2D dataset pretty easily. Now I want to segment multiple labels on the same dataset, but the labels overlap. I read in other issues that there is a workaround similar to what was done for the BraTS 2020 challenge. Is there a guide I can look into to get a good grasp of how to allow overlapping labels (e.g., a section in the paper or specific files)? Thank you very much!

FabianIsensee commented 3 years ago

This would be rather difficult to explain in detail. To give you food for thought: imagine you have label 1 and label 2 and they overlap. The solution would be to create a third label, label 3, defined as (label1 & label2) (logical AND; label 3 is where the two overlap). Then, when you train, you use region-based training as in BraTS and define your regions as ((1, 3), (2, 3)). That tells nnU-Net to merge labels 1 and 3 into region 1, and labels 2 and 3 into region 2. nnU-Net will then train on these overlapping regions. Best, Fabian
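A minimal sketch of that idea, assuming two binary masks stored as NumPy arrays. The names `mask1`, `mask2`, and `to_label_map` are made up for illustration and are not part of nnU-Net. Labels 1 and 2 end up marking the non-overlapping parts, and label 3 marks the overlap:

```python
import numpy as np

def to_label_map(mask1, mask2):
    """Encode two possibly overlapping binary masks as one integer label map.

    0 = background, 1 = only mask1, 2 = only mask2, 3 = overlap (mask1 & mask2).
    """
    mask1 = mask1.astype(bool)
    mask2 = mask2.astype(bool)
    label_map = np.zeros(mask1.shape, dtype=np.uint8)
    label_map[mask1] = 1
    label_map[mask2] = 2
    label_map[mask1 & mask2] = 3  # the overlap gets its own label
    return label_map

mask1 = np.array([[1, 1, 0, 0]])
mask2 = np.array([[0, 1, 1, 0]])
label_map = to_label_map(mask1, mask2)  # [[1, 3, 2, 0]]

# Regions as described above: region 1 merges labels (1, 3), region 2 merges (2, 3).
regions = ((1, 3), (2, 3))
region1 = np.isin(label_map, regions[0])
region2 = np.isin(label_map, regions[1])
```

Merging the labels of each region gives back the original masks, which is why training on the regions is equivalent to training on the overlapping labels.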

siavashk commented 1 year ago

@FabianIsensee I am facing a similar issue, but I do not think your proposed solution would work for my use case. For context, I am segmenting a family of pathologies of an anatomy. There are 8 pathologies (classes) in my dataset. However, most pathologies do not appear in isolation; frequently, 2 or more pathologies appear together at the same location.

Your solution would require the power set of all classes as labels for training. In the general case, for 8 classes, this would be 2 ** 8 = 256 training labels (including the empty set, i.e. background). Of course, it is rare for more than 3 pathologies to appear together, so if we limit labels to at most 3 simultaneous pathologies we will have C(8, 1) + C(8, 2) + C(8, 3) = 8 + 28 + 56 = 92 training labels.
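The two counts can be checked with a few lines of Python (a quick sanity check only; `n_classes` is just the number of pathologies mentioned above):

```python
from math import comb

n_classes = 8

# Every subset of the 8 classes, including the empty set (background).
all_combinations = 2 ** n_classes

# Only subsets with 1, 2, or 3 classes present at the same location.
at_most_three = sum(comb(n_classes, k) for k in range(1, 4))

print(all_combinations, at_most_three)  # 256 92
```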

There is probably a smarter way of doing this. Do you have any recommendations?

FabianIsensee commented 1 year ago

Hey, yes there is and we are working on it. But time is tight so it might take a while for this to come out in nnU-Net. Basically you can define each of your labels as a separate binary classification with a sigmoid function + Dice&BCE loss (instead of softmax + Dice&CE loss)
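To illustrate why a per-label sigmoid helps with overlap, here is a small NumPy sketch. It is not nnU-Net's implementation: the helper names (`sigmoid`, `softmax`, `bce_loss`, `soft_dice_loss`) and the unweighted sum of the two loss terms are assumptions made for this example.

```python
import numpy as np

def sigmoid(logits):
    return 1.0 / (1.0 + np.exp(-logits))

def softmax(logits, axis=0):
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def bce_loss(probs, target, eps=1e-7):
    # Binary cross-entropy, averaged over all channels and voxels.
    probs = np.clip(probs, eps, 1.0 - eps)
    return float(-np.mean(target * np.log(probs) + (1 - target) * np.log(1 - probs)))

def soft_dice_loss(probs, target, eps=1e-7):
    # One soft Dice term per channel, averaged over channels.
    axes = tuple(range(1, probs.ndim))
    intersection = (probs * target).sum(axis=axes)
    denominator = probs.sum(axis=axes) + target.sum(axis=axes)
    dice = (2.0 * intersection + eps) / (denominator + eps)
    return float(1.0 - dice.mean())

# Two label channels over four voxels; voxel 1 belongs to both labels.
target = np.array([[1, 1, 0, 0],
                   [0, 1, 1, 0]], dtype=float)
logits = np.where(target == 1, 4.0, -4.0)

probs_sigmoid = sigmoid(logits)  # each channel is decided independently
probs_softmax = softmax(logits)  # channels compete and sum to 1 per voxel

loss = bce_loss(probs_sigmoid, target) + soft_dice_loss(probs_sigmoid, target)
```

With sigmoid outputs, both channels can be confident at the overlapping voxel. With softmax, the two channels have to share the probability mass there, so neither can exceed 0.5 when both logits are equal.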

siavashk commented 1 year ago

Thank you. I just wanted to say for a repository that seems to be maintained by one person you are extremely fast in replying. Your work has been invaluable to the community.

vicoso commented 1 year ago

Hi! I want to use nnUNet to create a model that identifies kidneys in MRI images, and I have a question about labeling.

In the training dataset I have three modalities (T1, T2, Dixon) for each patient, and I also have a separate mask for each modality (one for T1, one for T2, and one for Dixon). However, as far as I understand from the example task structure, labels are not specified per modality. So I am not sure which of the three masks I should use for training, and whether it is generally OK to use a mask drawn on T1 for training on the T2 image as well.

coendevente commented 1 year ago

Hey!

> This would be rather difficult to explain in detail. To give you food for thought: imagine you have label 1 and label 2 and they overlap. The solution would be to create a third label, label 3, defined as (label1 & label2) (logical AND; label 3 is where the two overlap). Then, when you train, you use region-based training as in BraTS and define your regions as ((1, 3), (2, 3)). That tells nnU-Net to merge labels 1 and 3 into region 1, and labels 2 and 3 into region 2. nnU-Net will then train on these overlapping regions. Best, Fabian

I just wanted to add a more extensive explanation of this comment from @FabianIsensee for visual learners like myself.

The image below shows how to use nnU-Net in a multilabel setting on the right, with the corresponding "BraTS Region"-equivalents from here on the left:

[image: multilabel label combinations (right) and the corresponding BraTS regions (left)]

The multilabel example has three labels, resulting in 2**3=8 possible label values (the image shows 7, as it doesn't show the background label).
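As a small worked example for three labels, the combinations and the resulting regions can be enumerated like this. The label names `A`, `B`, `C` are hypothetical, and the ordering (subsets sorted by size) is one possible convention:

```python
from itertools import chain, combinations

label_names = ["A", "B", "C"]  # hypothetical label names

# All subsets, ordered by size: (), (A,), (B,), (C,), (A, B), ...
subsets = list(chain.from_iterable(
    combinations(label_names, r) for r in range(len(label_names) + 1)))

# Integer value written to the label file for each combination (0 = background).
value_of = {subset: value for value, subset in enumerate(subsets)}

# Region per label: every integer value whose combination contains that label.
regions = {name: [v for s, v in value_of.items() if name in s]
           for name in label_names}

print(len(value_of))  # 8
print(regions)        # {'A': [1, 4, 5, 7], 'B': [2, 4, 6, 7], 'C': [3, 5, 6, 7]}
```

Each label's region lists all integer values in which that label takes part, so value 4 (A and B together) appears in both region A and region B.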

Also, please note that this multilabel workaround does not really require setting regions_class_order to an appropriate value, as you would want to be using the probability files that nnU-Net writes anyway, not the integer maps.

Here's some code to write the dataset.json file for any multilabel problem:

from itertools import chain, combinations
import json

nnunet_dataset_json_file = '/path/to/dataset.json'  # replace this with your output file for dataset.json
n_labels = 13  # replace this with how many classes you have
numTraining = 0  # replace this with how many training images you have

def powerset(iterable):
    # powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
    s = list(iterable)
    return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))

labelscomb = list(powerset(range(n_labels)))
multilabels = range(len(labelscomb))
labelscomb_to_multilabel = dict(zip(labelscomb, multilabels))

"""
Now: len(labelscomb_to_multilabel) == 2 ** 13 == 8192

And `labelscomb_to_multilabel` will look something like this:
{(): 0,
 (0,): 1,
 (1,): 2,
 (2,): 3,
 ...
 (0, 1): 14,
 (0, 2): 15,  # meaning that when a voxel has label 0 and 2 overlap, it will get the integer value 15 in the nnU-Net nifti label file.
 (0, 3): 16,
 ...
 (4, 6, 9, 12): 999,
 ...
}

So `labelscomb_to_multilabel` is basically a look-up table for converting a combination of labels to the huge integer value that will be used in the nnU-Net nifti label file.
"""

# Writing labels for dataset.json
labels = {'background': 0}
labels = {**labels, **{str(biomarker_idx): [] for biomarker_idx in range(n_labels)}}

for biomarker_ids, multilabel in labelscomb_to_multilabel.items():
    for biomarker_idx in biomarker_ids:
        labels[str(biomarker_idx)].append(multilabel)

"""
`labels` now looks something like this:
{'background': 0,
 '0': [1,
  14,  # 14 is here ...
  15,
  16,
  17,
  ...
  ],
  '1': [2,
  14,  # ... and here, as 14 actually means exactly and only label '0' and '1' are co-located.
  26,
  27,
  28,
  29,
  ...
  ],
  ...
}
"""

dataset_dict = { 
    "channel_names": {  # formerly modalities
        "0": "OCT",
    }, 
    "labels": labels, 
    "numTraining": numTraining, 
    "file_ending": ".nii.gz",
    "regions_class_order": list(range(1, n_labels + 1)),
    # Add more meta info if needed
}

print('nnunet_dataset_json_file:', nnunet_dataset_json_file)

with open(nnunet_dataset_json_file, 'w') as f:
    f.write(json.dumps(dataset_dict, indent=4))

And here's a simple - but highly inefficient - function to convert your multilabel image (as a numpy array) to the correct format that nnU-Net can use (outputting a numpy array, so you still need to convert it to e.g. a SimpleITK object / nifti file):

import numpy as np

def image_to_nnunet_multilabel(input_volume):
    # Note `input_volume` should be one-hot encoded, so input_volume.shape == (n_labels, D, H, W)
    nnunet_output_volume = np.zeros(input_volume.shape[1:], dtype=int)

    for z in range(input_volume.shape[1]):
        for y in range(input_volume.shape[2]):
            for x in range(input_volume.shape[3]):
                labels_in_voxel = tuple(np.argwhere(input_volume[:, z, y, x])[:, 0])
                nnunet_output_volume[z, y, x] = labelscomb_to_multilabel[labels_in_voxel]

    return nnunet_output_volume

And a dummy example of how this function could be used:

a = np.zeros((13, 10, 10, 10))
a[0, :2, :2, :2] = 1
a[5, :1, :1, :1] = 1
b = image_to_nnunet_multilabel(a)

# a.shape, b.shape, np.unique(a), np.unique(b) == ((13, 10, 10, 10), (10, 10, 10), array([0., 1.]), array([ 0,  1, 18]))

The first slice of b now looks like this: [image: b[0] is zero everywhere except the top-left 2×2 block, which is [[18, 1], [1, 1]]]
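For completeness, the encoding can also be inverted. The following is a sketch, not part of nnU-Net; `build_lookup` and `nnunet_multilabel_to_onehot` are hypothetical helper names, and the lookup is rebuilt here so the snippet runs on its own:

```python
from itertools import chain, combinations
import numpy as np

def build_lookup(n_labels):
    """Same enumeration as the powerset-based lookup table above."""
    combos = chain.from_iterable(
        combinations(range(n_labels), r) for r in range(n_labels + 1))
    return {combo: value for value, combo in enumerate(combos)}

def nnunet_multilabel_to_onehot(label_volume, n_labels):
    """Invert the encoding: integer label map -> (n_labels, ...) binary array."""
    lookup = build_lookup(n_labels)
    onehot = np.zeros((n_labels,) + label_volume.shape, dtype=np.uint8)
    present_values = set(np.unique(label_volume).tolist())
    for combo, value in lookup.items():
        if value not in present_values:
            continue  # skip combinations that never occur in this volume
        voxels = label_volume == value
        for label_idx in combo:
            onehot[label_idx][voxels] = 1
    return onehot

# Round trip on the dummy example: value 18 is (0, 5), value 1 is (0,).
b = np.zeros((10, 10, 10), dtype=int)
b[:2, :2, :2] = 1
b[0, 0, 0] = 18
a_restored = nnunet_multilabel_to_onehot(b, n_labels=13)
```

Here `a_restored[0]` is set in the 2×2×2 corner and `a_restored[5]` only at voxel (0, 0, 0), matching the dummy input `a`.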

rahulghosh2 commented 11 months ago

@coendevente Thank you for this helpful explanation. Can you clarify what you mean by:

"this multilabel workaround does not really require setting regions_class_order to an appropriate value, as you would want to be using the probability files that nnU-Net writes anyway, not the integer maps."

My goal is to get a separate binary segmentation for each label in my dataset. I have 3 labels, 2 of which may overlap. How do you access this probability file rather than the integer map? Does the probability file have the same shape as the one-hot encoding (for example, (3, 512, 512) for 3 labels of shape (512, 512))? Thank you for advising.

coendevente commented 11 months ago

Assuming your classes 2 and 3 can overlap: if you set regions_class_order to, for example, [1, 2, 3], class 3 will always override class 2 in the integer map written during inference. That is not desired if you also want your model to handle overlapping classes during inference, of course. Therefore, you'll likely want to read the probability files instead of the integer maps that nnU-Net creates.

This file should indeed have a shape like (3, 512, 512). You can let nnU-Net create such a probability file by providing --save_probabilities when using nnUNetv2_predict.
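A sketch of the difference between the two outputs, using a hand-made probability array instead of a real prediction. The 0.5 threshold and the way the integer map is filled are assumptions for illustration; with real data you would load the array from the file written by --save_probabilities.

```python
import numpy as np

# Stand-in for the per-region probabilities, shape (n_regions, H, W).
# ASSUMPTION: with real predictions this array comes from the file
# written by --save_probabilities instead of being built by hand.
probabilities = np.zeros((3, 4, 4), dtype=np.float32)
probabilities[0, :2, :] = 0.9     # region 1: top half
probabilities[1, 1:3, 1:3] = 0.8  # region 2: centre square
probabilities[2, 2:, 2:] = 0.7    # region 3: bottom-right, overlaps region 2

# One independent binary mask per region; overlaps are preserved.
binary_masks = probabilities > 0.5

# What an integer map does: regions are written one after another,
# so a later entry in regions_class_order overwrites an earlier one.
regions_class_order = [1, 2, 3]
integer_map = np.zeros(probabilities.shape[1:], dtype=np.uint8)
for region_idx, label_value in enumerate(regions_class_order):
    integer_map[binary_masks[region_idx]] = label_value
```

At voxel (2, 2) both region 2 and region 3 are present in `binary_masks`, but `integer_map` only keeps the value 3 there. This is the overriding effect described above.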

yarinbar commented 10 months ago

> Hey, yes there is and we are working on it. But time is tight so it might take a while for this to come out in nnU-Net. Basically you can define each of your labels as a separate binary classification with a sigmoid function + Dice&BCE loss (instead of softmax + Dice&CE loss)

Any updates on that? I've been trying to do what you described, but the code is very complicated and I find it really difficult not to break everything :/

If I understand correctly, using masks of shape [(B,) n_cls, H, W] requires me to:

  1. Change the dataloaders so they can handle it
  2. Use force_use_labels = True, somehow combined with the label_dict, so we get has_regions=True
  3. Change the loss function to calculate each channel's loss and then combine them

Am I missing something?

avzh1 commented 5 months ago

> And here's a simple - but highly inefficient - function to convert your multilabel image (as a numpy array) to the correct format that nnU-Net can use (outputting a numpy array, so you still need to convert it to e.g. a SimpleITK object / nifti file):

Hey, thanks for the great explanation. For a more vectorized approach to that section of your code, I've rewritten it as follows for anyone who wants to adapt it:

def operation(x):
    # x holds the n_labels binary values of a single voxel.
    # The lookup table above uses 0-based label indices, so no offset is added here.
    labels_in_voxel = tuple(np.argwhere(x)[:, 0])
    return labelscomb_to_multilabel[labels_in_voxel]

def image_to_nnunet_multilabel(input_volume):
    # Note `input_volume` should be one-hot encoded, so input_volume.shape == (n_labels, D, H, W)
    # Flatten to (n_voxels, n_labels) so each row holds the labels of one voxel
    input_volume_per_channel = input_volume.reshape(input_volume.shape[0], -1).T
    input_volume_translated = np.apply_along_axis(operation, 1, input_volume_per_channel)
    return input_volume_translated.reshape(input_volume.shape[1:])

To test:

# quick test (assumes labelscomb_to_multilabel was built with n_labels = 7)

# create an array of values between 0 and 1
gt_slice = np.random.rand(20, 40, 45)
# threshold these values to 0 if below 0.5 or 1 otherwise
gt_slice = np.where(gt_slice < 0.5, 0, 1)
# repeat this volume across 7 label channels
gt = np.repeat(gt_slice[np.newaxis, :, :, :], 7, axis=0)
assert np.array_equal(gt[0], gt_slice)

output = image_to_nnunet_multilabel(gt)
assert gt_slice.shape == output.shape
# every foreground voxel has all 7 labels, i.e. the last of the 2**7 = 128 combinations (value 127)
assert np.array_equal(gt_slice * 127, output)
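Since np.apply_along_axis still calls a Python function once per voxel, a further option is to replace the per-voxel dictionary lookup with an array lookup. The sketch below is one way to do that, under the assumption that the integer values follow the same powerset ordering as the lookup table earlier in this thread. The names `build_bitmask_lut` and `image_to_nnunet_multilabel_lut` are made up for this example.

```python
from itertools import chain, combinations
import numpy as np

def build_bitmask_lut(n_labels):
    """Array that maps a bitmask (bit i set = label i present) to the
    integer value used by the powerset-ordered lookup table."""
    combos = chain.from_iterable(
        combinations(range(n_labels), r) for r in range(n_labels + 1))
    lut = np.zeros(2 ** n_labels, dtype=np.int64)
    for value, combo in enumerate(combos):
        bitmask = sum(1 << label_idx for label_idx in combo)
        lut[bitmask] = value
    return lut

def image_to_nnunet_multilabel_lut(input_volume):
    # input_volume: binary array of shape (n_labels, D, H, W) or (n_labels, H, W)
    n_labels = input_volume.shape[0]
    # Weight 2**i for channel i, shaped to broadcast over the spatial axes.
    weights = (1 << np.arange(n_labels, dtype=np.int64)).reshape(
        (-1,) + (1,) * (input_volume.ndim - 1))
    # One bitmask per voxel, computed for the whole volume at once.
    bitmasks = (input_volume.astype(bool) * weights).sum(axis=0)
    return build_bitmask_lut(n_labels)[bitmasks]

a = np.zeros((13, 10, 10, 10))
a[0, :2, :2, :2] = 1
a[5, :1, :1, :1] = 1
b = image_to_nnunet_multilabel_lut(a)
```

The only Python-level loop left runs over the 2 ** n_labels combinations when building the table, not over the voxels, so the cost per volume no longer depends on a per-voxel function call.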