MIC-DKFZ / nnUNet

Apache License 2.0

export_prediction_from_logits takes a very long time to finish #2540

Open pooya-mohammadi opened 3 weeks ago

pooya-mohammadi commented 3 weeks ago

The prediction step takes about 10-20 seconds on my system (RTX 4090 GPU):

100%|██████████| 120/120 [00:16<00:00, 7.30it/s]

But the function export_prediction_from_logits is extremely slow on the same system, which has a 32-core CPU:

[INFO] Elapsed time for export_prediction_from_logits: 534.6257283687592

I added a simple time tracker to measure this (about 535 seconds). What could be causing it?

The number of classes is the same as in TotalSegmentator's full set of structures. When I run TotalSegmentator on the same image, it finishes in less than 30 seconds!

Can anyone help me with this?

pooya-mohammadi commented 2 weeks ago

I also saved the data just before it is passed to export_prediction_from_logits and ran the export in a separate process. It took the same amount of time, while only 10-20 percent of the CPU was in use!

I noticed that TotalSegmentator uses the same function; however, it resizes/resamples the inputs before passing them to the nnUNet predictor.
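For context, resampling an input volume to a fixed voxel spacing before prediction can be sketched as below. This is only an illustration, assuming `scipy.ndimage.zoom` and a 1.5 mm target spacing; `resample_to_spacing` is a hypothetical helper, not TotalSegmentator's or nnUNet's actual code. When the input already has the spacing the model was trained at, there is presumably less resampling left for the export step to do.

```python
import numpy as np
from scipy.ndimage import zoom


def resample_to_spacing(volume, spacing, target_spacing=(1.5, 1.5, 1.5), order=3):
    """Hypothetical helper: resample a 3D volume from `spacing` to `target_spacing` (in mm)."""
    # A voxel of size s mm becomes s / t voxels of size t mm along each axis.
    factors = [s / t for s, t in zip(spacing, target_spacing)]
    return zoom(volume.astype(np.float32), factors, order=order)


# Example: a 0.75 mm isotropic volume shrinks by half along each axis at 1.5 mm.
image = np.random.rand(64, 64, 40).astype(np.float32)
resampled = resample_to_spacing(image, spacing=(0.75, 0.75, 0.75))
print(resampled.shape)  # (32, 32, 20)
```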

The reason nnUNet takes so long is that it resizes each class channel one by one in this code snippet:

for c in range(data.shape[0]):
    reshaped_final[c] = resize_fn(data[c], new_shape, order, **kwargs)
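To see why this loop gets expensive, here is a minimal, self-contained sketch of the same pattern. It assumes `scipy.ndimage.zoom` with order 3 (cubic spline) as a stand-in for nnUNet's own `resize_fn`, and `resample_per_channel` is a hypothetical name. Every class channel gets its own full 3D interpolation, so with a TotalSegmentator-sized label set (over 100 classes) the export performs over 100 large resizes.

```python
import numpy as np
from scipy.ndimage import zoom


def resample_per_channel(data, new_shape, order=3):
    """Hypothetical stand-in for the per-class loop: resample (C, X, Y, Z) channel by channel."""
    factors = [n / o for n, o in zip(new_shape, data.shape[1:])]
    out = np.empty((data.shape[0], *new_shape), dtype=np.float32)
    for c in range(data.shape[0]):
        # One full 3D cubic-spline interpolation per class channel,
        # so the runtime grows linearly with the number of classes C.
        out[c] = zoom(data[c], factors, order=order)
    return out


probs = np.random.rand(4, 16, 16, 10).astype(np.float32)  # 4 classes
resampled = resample_per_channel(probs, (24, 24, 15))
print(resampled.shape)  # (4, 24, 24, 15)
```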

@FabianIsensee Do you have any plans to change this, or to use an approach like TotalSegmentator's? Or is there any way to skip this step? For segmentations with a large number of classes, this becomes very slow.

ancestor-mithril commented 2 weeks ago

The reason TS2 is fast is that it uses nearest-neighbor (NN) interpolation, which creates less smooth masks. nnUNet uses tricubic resampling (or trilinear resampling if you configure your training to use torch resampling).
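The difference between the two interpolation choices can be shown on a toy label map. This is a minimal sketch assuming `scipy.ndimage.zoom`, not what either tool calls internally: nearest-neighbor (order 0) upsampling keeps only valid label values in a single pass, while linear (order 1) interpolation of the same map produces in-between values that are not valid labels. That is why smooth resampling has to be applied per class (to probabilities or one-hot masks) rather than to the label map directly.

```python
import numpy as np
from scipy.ndimage import zoom

# Toy 4x4x4 label map: background (0) on one side, label 5 on the other.
labels = np.zeros((4, 4, 4), dtype=np.int16)
labels[:, :, 2:] = 5

nearest = zoom(labels, 2, order=0)                    # nearest-neighbor
linear = zoom(labels.astype(np.float32), 2, order=1)  # trilinear

print(np.unique(nearest))  # only the original labels 0 and 5
print(np.any((linear > 0.5) & (linear < 4.5)))  # True: in-between values appear
```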

pooya-mohammadi commented 2 weeks ago

@ancestor-mithril Thanks for your response. Can you show me how to do that? Either way, this code snippet is too slow:

for c in range(data.shape[0]):
    reshaped_final[c] = resize_fn(data[c], new_shape, order, **kwargs)

Do you have any comments on this?

pooya-mohammadi commented 2 weeks ago

One more thing: since TS2 does not save probabilities, there is no need to resize each class separately. It only does a single zoom on the output segmentation.
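A minimal sketch of that single-zoom approach, assuming `scipy.ndimage.zoom` and a hypothetical helper named `argmax_then_resize` (this is not TotalSegmentator's actual code): take the argmax over the class axis first, then resample the resulting label map once with nearest-neighbor interpolation (order 0). The cost no longer depends on the number of classes, at the price of blockier boundaries than resampling the per-class values.

```python
import numpy as np
from scipy.ndimage import zoom


def argmax_then_resize(probs, new_shape):
    """Hypothetical helper: (C, X, Y, Z) class scores -> label map of shape new_shape."""
    # Collapse C channels into one label map before any resampling.
    seg = probs.argmax(axis=0).astype(np.uint16)
    factors = [n / o for n, o in zip(new_shape, seg.shape)]
    # A single nearest-neighbor resize, independent of the number of classes.
    return zoom(seg, factors, order=0)


probs = np.random.rand(6, 16, 16, 10).astype(np.float32)  # 6 classes
seg = argmax_then_resize(probs, (24, 24, 15))
print(seg.shape)  # (24, 24, 15)
```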

ancestor-mithril commented 2 weeks ago

You can change the resampling by using a different experiment planner (see the documentation). For example, you can use https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/experiment_planning/experiment_planners/resampling/resample_with_torch.py.

ancestor-mithril commented 2 weeks ago

> One more thing, since TS2 does not save probabilities, there is no need to do the resize on all the classes separately. And it only does a single zoom on the output segmentation.

This is not supported by nnUNet; you would have to change the code to get the TS2 behavior.

pooya-mohammadi commented 2 weeks ago

@ancestor-mithril Exactly. I noticed that when return_probabilities is set to False, it is faster to compute the segmentation first and then resample it with an order=0 resize. This drastically reduces the inference time. I'll create a pull request.
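The proposed change can be sketched as a branch on `return_probabilities`. This is a hypothetical illustration, not nnUNet's code or the eventual pull request: `resample_prediction` and the call counter are made up for the example, `scipy.ndimage.zoom` stands in for nnUNet's resampling, and order 1 is used on the probability path only to keep the example quick. The counter shows where the speedup comes from: C resizes versus one.

```python
import numpy as np
from scipy.ndimage import zoom

resize_calls = 0


def counted_zoom(x, factors, order):
    """Wrap zoom so the example can count how many 3D resizes each path performs."""
    global resize_calls
    resize_calls += 1
    return zoom(x, factors, order=order)


def resample_prediction(logits, new_shape, return_probabilities):
    """Hypothetical sketch: resample (C, X, Y, Z) logits to new_shape and return (segmentation, probabilities)."""
    factors = [n / o for n, o in zip(new_shape, logits.shape[1:])]
    if return_probabilities:
        # Probabilities are requested, so every class channel must be resampled.
        scores = np.stack([counted_zoom(ch, factors, order=1) for ch in logits])
        # Softmax over the class axis (numerically stabilized).
        probs = np.exp(scores - scores.max(axis=0, keepdims=True))
        probs /= probs.sum(axis=0, keepdims=True)
        return probs.argmax(axis=0), probs
    # Segmentation only: argmax first, then a single nearest-neighbor resize.
    seg = logits.argmax(axis=0).astype(np.uint16)
    return counted_zoom(seg, factors, order=0), None


logits = np.random.rand(5, 12, 12, 8).astype(np.float32)  # 5 classes

resize_calls = 0
seg_a, probs = resample_prediction(logits, (18, 18, 12), return_probabilities=True)
print(resize_calls)  # 5: one resize per class

resize_calls = 0
seg_b, _ = resample_prediction(logits, (18, 18, 12), return_probabilities=False)
print(resize_calls)  # 1: a single resize of the label map
```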