MIC-DKFZ / nnUNet


Multilabel Segmentation #1899

Closed: AniketP04 closed this issue 8 months ago

AniketP04 commented 10 months ago

Dear nnUNet devs,

I am seeking assistance with performing multilabel segmentation using nnUNet for 5 classes. If anyone can help, please feel welcome to provide guidance. My label samples look like this:

(attached label image: mask_color)

RajMalik23 commented 10 months ago

I'm encountering a similar issue. Can nnUNet be used for multi-label segmentation? Your assistance would be greatly appreciated.

I encountered a shape mismatch error when one-hot encoding mask images with five classes: the encoded mask has shape (256, 256, 5), while the corresponding original image has shape (256, 256, 3). Additionally, when using integer-encoded masks on a small dataset of 15 samples, training took approximately 4200 seconds per epoch.

mrokuss commented 9 months ago

Hey!

nnUNet works out of the box with multiple labels, as long as they are in ascending order, i.e. [0,1,2,3,4] is fine but not [0,1,3,5,9]. Make sure to add the labels correctly in the dataset.json as well, see here.

@RajMalik23 The shapes you describe seem to be (256, 256, 5), a segmentation mask with 5 classes, and (256, 256, 3), an RGB image with the typical three color channels. If the image is black and white, you might condense it to one channel; otherwise, specify the three color channels in the dataset.json. Usually you just have to get your dataset into the right format and nnUNet handles the rest!
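For example, here is a minimal sketch of how a color-coded RGB mask could be converted into the single-channel integer label map nnUNet expects. The color-to-label mapping below is just a placeholder, adjust it to the colors actually used in your masks:

import numpy as np
from PIL import Image

# hypothetical mapping from the RGB colors used in the masks to integer class ids
COLOR_TO_LABEL = {
    (0, 0, 0): 0,        # background
    (255, 255, 255): 1,  # Unlabeled
    (0, 255, 0): 2,      # Benign
    (255, 0, 0): 3,      # Malignant
    (0, 0, 255): 4,      # Artefact
}

def convert_color_mask(in_path, out_path):
    # read the color mask and build an empty integer label map of the same height/width
    rgb = np.array(Image.open(in_path).convert("RGB"))
    labels = np.zeros(rgb.shape[:2], dtype=np.uint8)
    # assign the class id wherever all three channels match the class color
    for color, class_id in COLOR_TO_LABEL.items():
        labels[np.all(rgb == color, axis=-1)] = class_id
    Image.fromarray(labels).save(out_path)

After conversion, the label files should only contain the integer values 0 through 4 that are declared in the dataset.json.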

Could this help?

Max

RajMalik23 commented 9 months ago

Hello @mrokuss,

I'm @RajMalik23, and I'd like to share the content of my dataset.json file:

{
    "channel_names": {
        "0": "R",
        "1": "G",
        "2": "B"
    },
    "labels": {
        "background": 0,
        "Unlabeled": 1,
        "Benign": 2,
        "Malignant": 3,
        "Artefact": 4
    },
    "numTraining": 15,
    "file_ending": ".png",
    "name": "Dataset050_Prostate"
}

However, when I initiate training for the first epoch, it takes an unexpectedly long time—around 4000 seconds—even though I'm training on a small dataset of just 15 samples. This prolonged time per epoch is concerning.

I'm unsure about what could be causing this delay or if there's something incorrect in my implementation. I would greatly appreciate your assistance in resolving this issue.

Thank you.

mrokuss commented 9 months ago

Hey @RajMalik23

So far your dataset.json looks good, so this is weird behavior. Did you then run nnUNetv2_plan_and_preprocess -d 50 -c 2d --verify_dataset_integrity, and did everything run through smoothly without warnings? May I ask how large the images are, and could you please supply me with the generated plans file? What kind of machine (CPU threads, RAM, GPU) are you running this on?
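Also, as a quick sanity check, something like the following (assuming your raw label files are PNGs in nnUNet_raw/Dataset050_Prostate/labelsTr; adjust the path to your setup) would show which integer values actually occur in your masks. They should only be 0 through 4:

import numpy as np
from PIL import Image
from pathlib import Path

# hypothetical path; adjust to where your nnUNet_raw folder lives
labels_dir = Path("nnUNet_raw/Dataset050_Prostate/labelsTr")
for label_file in sorted(labels_dir.glob("*.png")):
    values = np.unique(np.array(Image.open(label_file)))
    print(label_file.name, values)  # expected: a subset of [0 1 2 3 4]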

Best,

Max

RajMalik23 commented 9 months ago

Hello Max,

Yes, I have executed the nnUNetv2_plan_and_preprocess -d 50 -c 2d --verify_dataset_integrity command, and it completed without any warnings. I've uploaded all the files generated by both nnUNetv2_plan_and_preprocess -d 50 -c 2d --verify_dataset_integrity and nnUNetv2_train DATASET_NAME_OR_ID 2d FOLD to this temporary repository: https://github.com/RajMalik23/Issue1899-nnUNet. The repository only contains the files from the first epoch because I interrupted the training midway due to how long it was taking.

I'm currently running the training on CPU threads to ensure that everything is functioning correctly. The plan is to shift to GPU for the complete dataset once I confirm the model's performance on this smaller set of 15 samples.

Thank you for your assistance.

Best regards, Raj Malik

mrokuss commented 9 months ago

Hey @RajMalik23

I tested the dataset you supplied (nnUNet_preprocessed) on the GPU of my machine, an NVIDIA RTX 3090, and it runs through smoothly. I get epoch times of about 12 seconds, with the pseudo Dice quickly improving. The only exception is the last class, for which I get a NaN pseudo Dice; this is probably because that class does not appear in the validation split.

What might resolve your issue: Try actually training on the GPU. nnUNet does not care about the dataset size and will always do 250 iterations per epoch on the training set no matter how large.
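If your installed nnUNet v2 version supports it, you can also select the device explicitly when launching training; the exact flag may differ between versions, but in recent releases it is along the lines of:

nnUNetv2_train 50 2d 0 -device cuda

(using fold 0 here as an example). On the CPU, 250 iterations per epoch can easily take an hour or more, which would explain the epoch times you are seeing.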

Hope this helps!

Max