MIC-DKFZ / nnUNet

Apache License 2.0

Integrate nnUNet trained model in python code #1873

Closed: eyalhana closed this issue 9 months ago

eyalhana commented 9 months ago

Hi, thank you for the excellent repository!

I have a question: is there a way to integrate a trained model directly into Python code? Currently, I run the following command in the terminal:

nnUNetv2_predict -d Dataset510_Testsplits_cardiac -i "$input_data" -o "$output_data" -f 0 1 2 3 4 -tr nnUNetTrainer -c 2d -p nnUNetPlans --save_probabilities

I would like to use the model's forward pass directly inside a Python script, for instance: segmentation_numpy = model(image_numpy). Is this kind of integration possible?

Thanks!

TaWald commented 9 months ago

nnU-Net's predict does more than a single forward pass: your command preprocesses the raw cases, runs the forward passes, and ensembles the predictions of the selected folds. That is a lot more than a plain forward pass of the network itself. If you only need the predictions or the probabilities from nnU-Net, you should simply call the terminal command from your Python script.
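A minimal sketch of calling the command from the question inside Python, assuming nnUNetv2_predict is on the PATH and the nnUNet_results environment variable is set in the environment the script runs in. The helper names build_predict_command and run_prediction are hypothetical; the flags are exactly the ones from the original command. Passing the arguments as a list (instead of a shell string) avoids quoting problems with paths that contain spaces.

```python
import subprocess
from typing import List, Sequence


def build_predict_command(input_dir: str, output_dir: str,
                          dataset: str = "Dataset510_Testsplits_cardiac",
                          folds: Sequence[int] = (0, 1, 2, 3, 4),
                          configuration: str = "2d",
                          trainer: str = "nnUNetTrainer",
                          plans: str = "nnUNetPlans",
                          save_probabilities: bool = True) -> List[str]:
    """Hypothetical helper: assemble the nnUNetv2_predict call as an argument list."""
    cmd = ["nnUNetv2_predict", "-d", dataset, "-i", input_dir, "-o", output_dir,
           "-f", *[str(f) for f in folds],
           "-tr", trainer, "-c", configuration, "-p", plans]
    if save_probabilities:
        cmd.append("--save_probabilities")
    return cmd


def run_prediction(input_dir: str, output_dir: str, **kwargs) -> None:
    """Hypothetical helper: run the prediction and fail loudly on errors."""
    # check=True raises CalledProcessError if nnUNetv2_predict exits with a non-zero code
    subprocess.run(build_predict_command(input_dir, output_dir, **kwargs), check=True)
```

Usage would be run_prediction("/path/to/input", "/path/to/output"); the predictions (and .npz probabilities, since --save_probabilities is set) then land in the output folder and can be loaded from there.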

If you have an actual use case that requires direct access to the model, you can of course reopen this issue and explain why it is needed.

TaWald commented 9 months ago

For the tightest integration, though, check out the README of nnU-Net's inference module. The closest you can get to what you describe is explained in: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunetv2/inference/readme.md

    from nnunetv2.paths import nnUNet_results, nnUNet_raw
    import torch
    from batchgenerators.utilities.file_and_folder_operations import join
    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

    # instantiate the nnUNetPredictor
    predictor = nnUNetPredictor(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=True,
        perform_everything_on_device=True,
        device=torch.device('cuda', 0),
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=True
    )
    # initializes the network architecture, loads the checkpoint
    predictor.initialize_from_trained_model_folder(
        join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_lowres'),
        use_folds=(0,),
        checkpoint_name='checkpoint_final.pth',
    )
    # variant 1: give input and output folders
    predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'),
                                 join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres'),
                                 save_probabilities=False, overwrite=False,
                                 num_processes_preprocessing=2, num_processes_segmentation_export=2,
                                 folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)

Once you have initialized the predictor, you can wrap it in your own code and call whichever of its prediction methods you prefer.
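To get close to the segmentation_numpy = model(image_numpy) style from the question, one option is a thin wrapper around an initialized predictor. This is a hypothetical sketch, not part of nnU-Net: SegmentationModel is an illustrative name, and it assumes that predictor.predict_single_npy_array (the numpy-array entry point shown in the inference README) is called with (image, properties, previous-stage segmentation, output file, return probabilities), that the image is already in nnU-Net's (c, x, y, z) layout, and that a properties dict containing only 'spacing' is sufficient. Check the README for the exact array layout your configuration (e.g. 2d) expects.

```python
from typing import Any, Dict, Sequence

import numpy as np


class SegmentationModel:
    """Hypothetical convenience wrapper: segmentation = model(image).

    `predictor` is assumed to be an nnUNetPredictor on which
    initialize_from_trained_model_folder(...) has already been called.
    """

    def __init__(self, predictor: Any, spacing: Sequence[float]):
        self.predictor = predictor
        # Assumption: 'spacing' is the only image property needed for prediction here
        self.spacing = [float(s) for s in spacing]

    def __call__(self, image: np.ndarray) -> np.ndarray:
        # Assumption: a 3D array is a single-channel (x, y, z) volume, so add a channel axis
        if image.ndim == 3:
            image = image[np.newaxis]
        if image.ndim != 4:
            raise ValueError(f"expected (x, y, z) or (c, x, y, z), got shape {image.shape}")
        props: Dict[str, Any] = {"spacing": self.spacing}
        # Positional arguments mirror the README call:
        # image, properties, previous-stage segmentation, output file, return probabilities
        return self.predictor.predict_single_npy_array(
            image.astype(np.float32), props, None, None, False
        )
```

With the predictor from the snippet above, usage would look like model = SegmentationModel(predictor, spacing=(1.0, 1.0, 1.0)) followed by segmentation_numpy = model(image_numpy), with the spacing replaced by your image's real spacing. Note that each call still runs nnU-Net's full pipeline for one case (preprocessing, sliding-window prediction, ensembling of the loaded folds), not just the bare network forward pass.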

TaWald commented 9 months ago

Closing due to inactivity