MIC-DKFZ / nnUNet

Apache License 2.0

Is there any way to serve the trained model with torchserve? #1219

Open atisman89 opened 2 years ago

atisman89 commented 2 years ago

I trained a model on a single fold and tested it with nnUNet_predict, and it appears to perform well. Is there a way to serve the trained model using torchserve? The trained model is located at: nnUNet_trained_models/nnUNet/3d_fullres/.../nnUNetTrainerV2__nnUNetPlansv2.1/fold_0

I first tried to use the output model model_final_checkpoint.model directly:

torch-model-archiver --model-name <my-model> --version 0.9 --serialized-file model_final_checkpoint.model --handler image_segmenter
mv <my-model>.mar ./model-store/
torchserve --start --ncs --model-store model-store --models <my-model>.mar

which was not successful due to the following error:

2022-11-13T00:16:33,159 [INFO ] W-9000-nnunet-lungseg64_0.9-stdout MODEL_LOG -   File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/ts/torch_handler/base_handler.py", line 115, in _load_torchscript_model
2022-11-13T00:16:33,161 [INFO ] W-9000-nnunet-lungseg64_0.9-stdout MODEL_LOG -     return torch.jit.load(model_pt_path, map_location=self.device)
2022-11-13T00:16:33,161 [INFO ] W-9000-nnunet-lungseg64_0.9-stdout MODEL_LOG -   File "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torch/jit/_serialization.py", line 162, in load
2022-11-13T00:16:33,162 [INFO ] W-9000-nnunet-lungseg64_0.9-stdout MODEL_LOG -     cpp_module = torch._C.import_ir_module(cu, str(f), map_location, _extra_files)
2022-11-13T00:16:33,162 [INFO ] W-9000-nnunet-lungseg64_0.9-stdout MODEL_LOG - RuntimeError: PytorchStreamReader failed locating file constants.pkl: file not found
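The traceback shows torch.jit.load looking for a constants.pkl entry inside the file. As a rough pre-flight check before archiving, one could test for that entry with the standard library alone. This is only a heuristic sketch: it assumes TorchScript archives are zip files that contain a constants.pkl entry, as the error implies, and the helper name looks_like_torchscript is made up for illustration.

```python
import zipfile


def looks_like_torchscript(path: str) -> bool:
    """Heuristic: True if `path` is a zip archive with a constants.pkl entry.

    Assumption: torch.jit.save writes such an entry, while a plain
    torch.save checkpoint does not. This does not validate the model itself.
    """
    if not zipfile.is_zipfile(path):
        return False
    with zipfile.ZipFile(path) as archive:
        # Entries are stored as "<archive_name>/constants.pkl", so compare
        # only the final path component.
        return any(name.split("/")[-1] == "constants.pkl"
                   for name in archive.namelist())
```

Calling looks_like_torchscript("model_final_checkpoint.model") would be expected to return False for a checkpoint that triggers the error above.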

So the trained model did not appear to be in TorchScript format, and I modified the nnunet code to save it as TorchScript, like this (in nnunet/training/network_training/network_trainer.py):

    def save_final_checkpoint_in_torchscript(self):
        self.load_final_checkpoint()
        self.network.eval()
        model_scripted = torch.jit.script(self.network)
        scripted_model_path = join(self.output_folder, 'model_final_checkpoint_scripted.pt')
        model_scripted.save(scripted_model_path)

which caused an error while running torch.jit.script:

RuntimeError: 
Expected integer literal for index. ModuleList/Sequential indexing is only supported with integer literals. Enumeration is supported, e.g. 'for index, v in enumerate(self): ...':
  File "/home/ubuntu/code/nnUNet/nnunet/network_architecture/generic_UNet.py", line 400
            # module: ModuleInterface = self.conv_blocks_context[d]
            # x = module(x)
            x = self.conv_blocks_context[d](x)
                ~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            skips.append(x)
            if not self.convolutional_pooling:

I dug a bit deeper but could not get it to work. Has anyone tried the same? Thanks.
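For reference, the error message itself points at one workaround: iterate over the ModuleList with enumerate instead of indexing it with a loop variable. The sketch below shows the pattern on a toy module. TinyEncoder is hypothetical and is not the actual Generic_UNet forward pass; the real network has further layers and control flow that may need similar changes before it scripts cleanly.

```python
from typing import List, Tuple

import torch
from torch import nn


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for an encoder that collects skip connections."""

    def __init__(self, num_blocks: int = 3, channels: int = 4):
        super().__init__()
        self.conv_blocks_context = nn.ModuleList(
            [nn.Conv2d(channels, channels, kernel_size=3, padding=1)
             for _ in range(num_blocks)]
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        skips: List[torch.Tensor] = []
        # enumerate() over the ModuleList is scriptable, whereas
        # self.conv_blocks_context[d] with a loop variable d is not.
        for d, block in enumerate(self.conv_blocks_context):
            x = block(x)
            skips.append(x)
        return x, skips


# Scripting succeeds because no ModuleList is indexed with a non-literal.
scripted = torch.jit.script(TinyEncoder().eval())
out, skips = scripted(torch.zeros(1, 4, 8, 8))
```

If rewriting the loops is impractical, torch.jit.trace with a fixed-size example input is another option to try, at the cost of baking in whatever control flow the example input triggers.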

dojoh commented 1 year ago

Hello, sorry for the late response. Is this question still relevant, or did you manage to solve it?