pooya-mohammadi opened this issue 3 weeks ago
I also saved the data before passing it to export_prediction_from_logits and ran it in a separate process; the output took the same amount of time, while only 10-20 percent of the CPU was occupied!
I noticed that TotalSegmentator uses the same function; however, it resizes/resamples the inputs before passing them to the nnUNet predictor.
And the reason nnUNet takes so long is that it resizes each class channel one by one in this code snippet:

```python
for c in range(data.shape[0]):
    reshaped_final[c] = resize_fn(data[c], new_shape, order, **kwargs)
```
@FabianIsensee Do you have any plans to change this or to use a method like TotalSegmentator's? Or is there any way to skip this part? For segmentations with a large number of classes this becomes very slow.
The reason TS2 is fast is that it uses nearest-neighbor (NN) interpolation (creating less smooth masks). nnUNet uses tricubic resampling (or trilinear resampling if you configure your training to use torch resampling).
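For illustration, here is a minimal sketch of the difference between the two interpolation modes, using scipy.ndimage.zoom on an arbitrary 3D array (this is not nnUNet's actual resampling code; shapes are made up):

```python
import numpy as np
from scipy.ndimage import zoom

# Arbitrary low-resolution probability map, upsampled 2x per axis.
prob_map = np.random.rand(64, 64, 64).astype(np.float32)

smooth = zoom(prob_map, 2.0, order=3)  # tricubic spline: smooth masks, but expensive
blocky = zoom(prob_map, 2.0, order=0)  # nearest-neighbor: fast, but blockier (TS2-style)
```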
@ancestor-mithril Thanks for your response. Can you show me how to do it? However, this code snippet is too slow:

```python
for c in range(data.shape[0]):
    reshaped_final[c] = resize_fn(data[c], new_shape, order, **kwargs)
```
Do you have any comments on this?
One more thing: since TS2 does not save probabilities, there is no need to resize all the classes separately; it only does a single zoom on the output segmentation.
You can change the resampling by using a different experiment planner (see the documentation). For example, you can use https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/experiment_planning/experiment_planners/resampling/resample_with_torch.py.
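As a hedged example, a custom planner is usually selected at preprocessing time via the `-pl` flag, e.g. `nnUNetv2_plan_and_preprocess -d DATASET_ID -pl <PlannerClassName>`; the placeholder class name here is an assumption, so check the linked file for the actual planner class.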
> One more thing: since TS2 does not save probabilities, there is no need to resize all the classes separately; it only does a single zoom on the output segmentation.
This is not supported by nnUNet; you would have to change the code to obtain the TS2 behavior.
@ancestor-mithril Exactly. I noticed that when return_probabilities is set to False, it is faster to get the segmentation first and then send it to resampling with an order=0 resize. This drastically reduces the inference time.
I'll create a pull request.
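For reference, a minimal sketch of that idea, assuming logits come in as a (num_classes, x, y, z) array (this is not the actual nnUNet or PR code):

```python
import numpy as np
from scipy.ndimage import zoom

def resample_segmentation_fast(logits: np.ndarray, new_shape) -> np.ndarray:
    """Argmax at low resolution, then a single nearest-neighbor resize.

    Replaces the per-class resize loop with one order=0 zoom; only valid
    when probabilities are not needed (return_probabilities=False).
    """
    seg = logits.argmax(0).astype(np.uint8)                 # (x, y, z) label map
    factors = [n / o for n, o in zip(new_shape, seg.shape)]
    return zoom(seg, factors, order=0)                      # single NN pass
```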
The prediction process takes about 10-20 seconds on my system (4090 GPU):

```
100%|██████████| 120/120 [00:16<00:00, 7.30it/s]
```

But the function export_prediction_from_logits is super slow on my system with 32 CPU cores:

```
[INFO] Elapsed time for export_prediction_from_logits: 534.6257283687592
```

I added a simple time tracker. What could have caused this? The number of classes equals the number of classes TotalSegmentator predicts for all structures. I ran TotalSegmentator on the same image and it takes less than 30 seconds to finish!
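For context, the timing above was presumably collected with something like the following (an assumption, since the instrumentation itself was not shown in the issue):

```python
import time

start = time.perf_counter()
# export_prediction_from_logits(...)  # the nnUNet export call being timed
elapsed = time.perf_counter() - start
print(f"[INFO] Elapsed time for export_prediction_from_logits: {elapsed}")
```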
Can anyone help me with this?