MIC-DKFZ / nnUNet

Region-based training cannot be used with ensemble #2401

Closed gaojh135 closed 1 month ago

gaojh135 commented 3 months ago

Hello, I have some questions about region-based training. I am working on a task of segmenting myocardial scars and edema, which I have divided into two stages: coarse segmentation and fine segmentation.

For coarse segmentation, I run the following command:

nnUNetv2_find_best_configuration 102 -c 2d 3d_fullres -p nnUNetResEncUNetLPlans

The output is

***All results:***
nnUNetTrainer__nnUNetResEncUNetLPlans__2d: 0.9073983062528307
nnUNetTrainer__nnUNetResEncUNetLPlans__3d_fullres: 0.907752261947338
ensemble___nnUNetTrainer__nnUNetResEncUNetLPlans__2d___nnUNetTrainer__nnUNetResEncUNetLPlans__3d_fullres___0_1_2_3_4: 0.9111222632429284
*Best*: ensemble___nnUNetTrainer__nnUNetResEncUNetLPlans__2d___nnUNetTrainer__nnUNetResEncUNetLPlans__3d_fullres___0_1_2_3_4: 0.9111222632429284

This is correct, and the ensemble can be used. But for fine segmentation, when I run the following command:

nnUNetv2_find_best_configuration 102 -c 2d 3d_fullres -p nnUNetResEncUNetLPlans

The output is

***All results:***
nnUNetTrainer__nnUNetResEncUNetLPlans__2d: 0.8446568426083821
nnUNetTrainer__nnUNetResEncUNetLPlans__3d_fullres: 0.8513131411230015
ensemble___nnUNetTrainer__nnUNetResEncUNetLPlans__2d___nnUNetTrainer__nnUNetResEncUNetLPlans__3d_fullres___0_1_2_3_4: 0.022358684648249028
*Best*: nnUNetTrainer__nnUNetResEncUNetLPlans__3d_fullres: 0.8513131411230015

The result for the ensemble configuration is clearly incorrect. Does region-based training not support ensembling?

gaojh135 commented 3 months ago

dataset.json for coarse segmentation

 ……
 "labels": {  
   "background": 0,
   "myo": 1,
   "lv": 2,
   "rv": 3
  }, 
 ……

dataset.json for fine segmentation

……
 "labels": {  
   "background": 0,
   "myo": [1, 2, 3],
   "edema&scar":[2, 3],
   "scar":[3]
  }, 
 "regions_class_order": [1, 2, 3],
 ……

I have already processed the labels to ensure that the training dataset labels are correct in fine segmentation.
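For readers unfamiliar with region-based labels, here is a minimal sketch of how the region definitions above relate to an ordinary label map. It uses only NumPy and does not call nnU-Net; the helper names `label_map_to_region_masks` and `region_masks_to_label_map` are hypothetical and exist only for this illustration. The thresholding at 0.5 and the write order follow `regions_class_order`, as described later in this thread.

```python
import numpy as np

# Region definitions copied from the fine-segmentation dataset.json above.
regions = {"myo": [1, 2, 3], "edema&scar": [2, 3], "scar": [3]}
regions_class_order = (1, 2, 3)


def label_map_to_region_masks(seg, regions):
    """Hypothetical helper: one binary mask per region, stacked on axis 0."""
    return np.stack([np.isin(seg, labels) for labels in regions.values()]).astype(np.uint8)


def region_masks_to_label_map(masks, class_order):
    """Hypothetical helper: later regions overwrite earlier ones."""
    seg = np.zeros(masks.shape[1:], dtype=np.uint8)
    for i, c in enumerate(class_order):
        seg[masks[i] > 0.5] = c
    return seg


seg = np.array([[0, 1, 2, 3]], dtype=np.uint8)
masks = label_map_to_region_masks(seg, regions)
restored = region_masks_to_label_map(masks, regions_class_order)
print(masks[:, 0, :])
print(restored)
```

Each region is an overlapping binary channel, which is why a per-channel sigmoid rather than a softmax is used at inference time for such datasets.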

antonie-z commented 2 months ago

I am experiencing a similar issue: each single model evaluates to a reasonable Dice, but the ensemble degrades to an extremely low Dice (such as 0.02). This problem only occurs with region-based ensembles. I later found that the cause is that apply_inference_nonlin is applied twice in the region-based ensemble:

  1. The first apply_inference_nonlin call happens when the .npz file is saved (in nnunetv2.inference.export_prediction.convert_predicted_logits_to_segmentation_with_correct_shape), in the line

    predicted_probabilities = label_manager.apply_inference_nonlin(predicted_logits)

    Then, the probability map is saved in the .npz file through the line

    predicted_probabilities = label_manager.revert_cropping_on_probabilities(predicted_probabilities,
    properties_dict['bbox_used_for_cropping'],properties_dict['shape_before_cropping'])
  2. The .npz file saved in step 1 has therefore already been passed through apply_inference_nonlin (torch.sigmoid) and can be segmented directly with a threshold of 0.5. However, during ensembling (nnunetv2.ensembling.ensemble.merge_files), the line segmentation = label_manager.convert_logits_to_segmentation(probabilities) applies label_manager.apply_inference_nonlin a second time. Because a sigmoid maps values in [0, 1] to roughly [0.5, 0.73], the check segmentation[predicted_probabilities[i] > 0.5] = c in LabelManager's convert_probabilities_to_segmentation can no longer separate foreground from background.

    segmentation = label_manager.convert_logits_to_segmentation(probabilities)

    def convert_logits_to_segmentation(self, predicted_logits: Union[np.ndarray, torch.Tensor]) -> \
            Union[np.ndarray, torch.Tensor]:
        input_is_numpy = isinstance(predicted_logits, np.ndarray)
        # The nonlinearity is applied here, even if the input already holds probabilities.
        probabilities = self.apply_inference_nonlin(predicted_logits)
        if input_is_numpy and isinstance(probabilities, torch.Tensor):
            probabilities = probabilities.cpu().numpy()
        return self.convert_probabilities_to_segmentation(probabilities)
  3. My suggestion is to replace inference_nonlin with an identity function before the line segmentation = label_manager.convert_logits_to_segmentation(probabilities) is executed in nnunetv2.ensembling.ensemble.merge_files, like so:

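    def identity_function(logits: torch.Tensor) -> torch.Tensor:
        # The inputs are already probabilities, so pass them through unchanged.
        return logits

    if label_manager.has_regions:
        # Skip the second sigmoid for region-based models.
        label_manager.inference_nonlin = identity_function
    # Unchanged call; it runs for region-based and label-based datasets alike.
    segmentation = label_manager.convert_logits_to_segmentation(probabilities)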

    This allows region-based ensembling to run correctly.
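The effect of the repeated sigmoid described in the steps above can be reproduced with a small NumPy-only sketch. It does not call nnU-Net: `sigmoid` and `regions_to_segmentation` are hypothetical stand-ins for torch.sigmoid and for the thresholding loop in convert_probabilities_to_segmentation, and the logits are made-up toy values for four voxels and three regions.

```python
import numpy as np


def sigmoid(x):
    """Stand-in for torch.sigmoid in this illustration."""
    return 1.0 / (1.0 + np.exp(-x))


def regions_to_segmentation(probs, class_order=(1, 2, 3)):
    """Hypothetical mirror of the loop `segmentation[probs[i] > 0.5] = c`."""
    seg = np.zeros(probs.shape[1:], dtype=np.uint8)
    for i, c in enumerate(class_order):
        seg[probs[i] > 0.5] = c
    return seg


# Toy logits: rows are the regions myo, edema&scar, scar; columns are voxels.
logits = np.array([
    [-4.0,  3.0,  3.0, 3.0],
    [-4.0, -3.0,  3.0, 3.0],
    [-4.0, -3.0, -3.0, 3.0],
])

probs = sigmoid(logits)                             # what the .npz file stores
correct = regions_to_segmentation(probs)            # nonlinearity applied once
broken = regions_to_segmentation(sigmoid(probs))    # nonlinearity applied twice

print(correct)  # [0 1 2 3]
print(broken)   # [3 3 3 3]
```

With the nonlinearity applied twice, every channel exceeds 0.5 at every voxel, so the last region in regions_class_order overwrites the whole image. This would plausibly explain a Dice near 0.02 for the ensemble, and it is why passing the stored probabilities through unchanged restores the expected result.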

gaojh135 commented 2 months ago

Thank you very much for your reply, it's working now.