Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

Exporting SwinUNETR to ONNX does not work #5125

Closed nicktasios closed 2 years ago

nicktasios commented 2 years ago

Describe the bug

When trying to export the SwinUNETR model from MONAI, I get the error:

RuntimeError: Failed to export an ONNX attribute 'onnx::Gather', since it's not constant, please try to make things (e.g., kernel size) static if possible.

In a different issue, I read that this might be fixed by changing x_shape = x.size() to x_shape = [int(s) for s in x.size()] in the problematic code -- I found that the problem manifests at proj_out(). Doing this, though, results in a different error:

RuntimeError: Unsupported: ONNX export of instance_norm for unknown channel size.

Making this change in all places where I find x_shape = x.size() results in a floating point exception!
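For reference, the change being discussed is of this form (a sketch; the exact locations are in monai/networks/nets/swin_unetr.py and vary by version). Calling int() on each dimension turns traced shape values into plain Python constants that the ONNX exporter can fold:

    -     x_shape = x.size()
    +     x_shape = [int(s) for s in x.size()]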

To Reproduce

Here is a minimal example demonstrating the issue:

from monai.networks.nets import SwinUNETR
import torch

if __name__ == '__main__':

    model = SwinUNETR(img_size=(96, 96, 96),
                      in_channels=1,
                      out_channels=5,
                      feature_size=48,
                      drop_rate=0.0,
                      attn_drop_rate=0.0,
                      dropout_path_rate=0.0,
                      use_checkpoint=True,
                      )
    inputs = [torch.randn([1, 1, 96, 96, 96])]
    input_names = ['input']
    output_names = ['output']

    torch.onnx.export(
        model,
        tuple(inputs), 'model.onnx',
        verbose=False,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=None,
        opset_version=11,
    )

Environment

================================
MONAI version: 0.9.1
Numpy version: 1.23.2
Pytorch version: 1.12.1.post200
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 356d2d2f41b473f588899d705bbc682308cee52c
MONAI file: .../envs/temp_env/lib/python3.10/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: NOT INSTALLED or UNKNOWN VERSION.
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
Pillow version: NOT INSTALLED or UNKNOWN VERSION.
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: NOT INSTALLED or UNKNOWN VERSION.
tqdm version: NOT INSTALLED or UNKNOWN VERSION.
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: NOT INSTALLED or UNKNOWN VERSION.
pandas version: NOT INSTALLED or UNKNOWN VERSION.
einops version: 0.4.1
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

================================
Printing system config...
================================
psutil required for print_system_info

================================
Printing GPU config...
================================
Num GPUs: 1
Has CUDA: True
CUDA version: 11.2
cuDNN enabled: True
cuDNN version: 8401
Current device: 0
Library compiled for CUDA architectures: ['sm_35', 'sm_50', 'sm_60', 'sm_61', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'compute_86']
GPU 0 Name: Tesla T4
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 40
GPU 0 Total memory (GB): 14.6
GPU 0 CUDA capability (maj.min): 7.5

Additional context

I have also filed an issue with PyTorch, as I'm not certain on which side the bug should be resolved.

wyli commented 2 years ago

Could you please elaborate on the floating point exception? We do have a recent patch https://github.com/Project-MONAI/MONAI/pull/4913 -- not sure if it's related.

nicktasios commented 2 years ago

@wyli I tried the change in that patch and it resulted in the same problem. The exception happens in the following:

Thread 1 "python" received signal SIGFPE, Arithmetic exception.                                                                                                  
0x00002aaae8980949 in torch::jit::(anonymous namespace)::ComputeShapeFromReshape(torch::jit::Node*, c10::SymbolicShape const&, c10::SymbolicShape const&, int) ()
   from /home/014118_emtic_oncology/Pancreas/nick/envs/temp_env/lib/python3.10/site-packages/torch/lib/libtorch_python.so                                        
wyli commented 2 years ago

It seems that using x_shape = [int(s) for s in x.size()], as you mentioned in the PyTorch thread, partly addresses the issue.

It's possible to export with these parameter changes:

    model = SwinUNETR(img_size=(96, 96, 96), 
                      in_channels=1,         
                      out_channels=5,        
                      feature_size=48,
+                     norm_name=("instance", {"affine": True}),
                      drop_rate=0.0,         
                      attn_drop_rate=0.0,    
                      dropout_path_rate=0.0, 
-                     use_checkpoint=True,
+                     use_checkpoint=False,
                      )               

I don't think gradient checkpointing (use_checkpoint=True) could easily be supported here.
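Putting it together, a minimal export sketch with these parameters (the same script as the reproduction above, only the constructor arguments changed):

    from monai.networks.nets import SwinUNETR
    import torch

    model = SwinUNETR(img_size=(96, 96, 96),
                      in_channels=1,
                      out_channels=5,
                      feature_size=48,
                      # affine=True gives instance norm a known channel size
                      norm_name=("instance", {"affine": True}),
                      drop_rate=0.0,
                      attn_drop_rate=0.0,
                      dropout_path_rate=0.0,
                      # gradient checkpointing is not supported during export
                      use_checkpoint=False,
                      )
    model.eval()
    torch.onnx.export(model, torch.randn(1, 1, 96, 96, 96), 'model.onnx',
                      input_names=['input'], output_names=['output'],
                      opset_version=11)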

nicktasios commented 2 years ago

@wyli I tried the changes you suggested and indeed, the model was successfully exported to ONNX. Unfortunately, during inference I get the following error:

[E:onnxruntime:ONNX_ENGINE, tensorrt_execution_provider.h:51 log] [2022-10-04 14:23:43   ERROR] 10: [optimizer.cpp::computeCosts::2011] Error Code 10: Internal Error (Could not find any implementation for node {ForeignNode[Reshape_1863 + Transpose_1864...Add_1962]}.)
Exception caught: TensorRT EP could not build engine for fused node: TensorrtExecutionProvider_TRTKernel_graph_torch-jit-export_18064542951413184515_8_1
terminate called after throwing an instance of 'Ort::Exception'
  what():  TensorRT EP could not build engine for fused node: TensorrtExecutionProvider_TRTKernel_graph_torch-jit-export_18064542951413184515_8_1
tangy5 commented 1 year ago

Hi @wyli , if you happen to see this in this closed issue: Swin UNETR can now be safely converted with torch.onnx.export or torch.jit.trace. But it cannot be converted by torch.jit.script, since it contains a lot of dynamic Python, including gradient checkpointing (use_checkpoint=True), which is not supported.

Do you have any idea on whether we need to make the Swin UNETR model scriptable with torch.jit.script? torch.jit.trace and ONNX have some limitations, such as not supporting branching.

Thank you.
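For reference, the torch.jit.trace route mentioned above, as a minimal sketch (assuming use_checkpoint=False, per the earlier discussion):

    import torch
    from monai.networks.nets import SwinUNETR

    model = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=5,
                      feature_size=48, use_checkpoint=False)
    model.eval()
    # tracing records the ops for one fixed example input, so it cannot
    # capture data-dependent branching -- hence the torch.jit.script question
    traced = torch.jit.trace(model, torch.randn(1, 1, 96, 96, 96))
    traced.save('swin_unetr_traced.pt')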

wyli commented 1 year ago

I looked into that some time ago, and torch.jit.script support is not easy for this model; perhaps we shouldn't spend more time on this for now. If it's really needed, we may have to write a less flexible version of the model and hard-code some hyperparameters.

tangy5 commented 1 year ago

> I looked into that some time ago, and torch.jit.script support is not easy for this model; perhaps we shouldn't spend more time on this for now. If it's really needed, we may have to write a less flexible version of the model and hard-code some hyperparameters.

Thanks, I agree. I was trying to rewrite this model these days to support torch.jit.script, but got stuck with the use_checkpoint option. Maybe we keep it traceable with torch.jit.trace for now. If we need to support TensorRT specifically in the future, we can write another version of this model.
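For illustration, the pattern that blocks scripting looks roughly like this (a hypothetical simplification, not the actual MONAI source):

    import torch
    import torch.nn as nn
    import torch.utils.checkpoint

    class Stage(nn.Module):
        """Hypothetical simplification of a transformer stage."""
        def __init__(self, blocks: nn.ModuleList, use_checkpoint: bool):
            super().__init__()
            self.blocks = blocks
            self.use_checkpoint = use_checkpoint

        def forward(self, x):
            for blk in self.blocks:
                if self.use_checkpoint:
                    # checkpoint() takes an arbitrary Python callable,
                    # which torch.jit.script cannot compile
                    x = torch.utils.checkpoint.checkpoint(blk, x)
                else:
                    x = blk(x)
            return x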

Nic-Ma commented 1 year ago

Hi @wyli @tangy5 ,

I think TorchScript and TensorRT support are a "must-have" for the next version's release? Let's discuss it later.

Thanks.

tangy5 commented 1 year ago

> Hi @wyli @tangy5 ,
>
> I think TorchScript and TensorRT support are a "must-have" for the next version's release? Let's discuss it later.
>
> Thanks.

Sure. The point is whether we need torch.jit.script for the TorchScript model, or whether torch.jit.trace and ONNX are good enough. If torch.jit.script is a must-have, I suggest we write a light version of this model (e.g., swinunetr_lt) and remove the dynamically programmed sections. But we might have other options... Let's discuss. Thank you.

Nic-Ma commented 1 year ago

I think TorchScript -> TensorRT is the recommended way now, instead of the previous ONNX -> TensorRT. Here is an example: https://github.com/pytorch/TensorRT/blob/main/README.md?plain=1#L88 I think we may need to connect with the TensorRT team for more details later, CC @deepib .

Thanks.
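For reference, the TorchScript -> TensorRT route from that README looks roughly like this (a sketch, not verified against Swin UNETR; traced_model is a torch.jit.trace'd network as above):

    import torch
    import torch_tensorrt

    # compile a traced module into a TensorRT-embedded TorchScript module
    trt_model = torch_tensorrt.compile(
        traced_model,
        inputs=[torch_tensorrt.Input((1, 1, 96, 96, 96))],
        enabled_precisions={torch.half},  # run layers in FP16 where possible
    )
    out = trt_model(torch.randn(1, 1, 96, 96, 96).cuda())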

wyli commented 1 year ago

Also, if it's for inference only, the option of using gradient checkpointing is not needed.

csheaff commented 1 year ago

Hello @tangy5 @wyli and others, I have tried and failed to use the above-mentioned workarounds to convert SwinUNETR to TensorRT format. I'm using MONAI version 1.2.0rc3. I am attempting two routes to TensorRT: (1) through torch_tensorrt.ts.compile after torch.jit.trace, and (2) using torch.onnx.export. Both result in the errors shown below. It seems that torch.jit.trace does not successfully trace the graph, but perhaps I'm doing something wrong here. Any help would be much appreciated. Note that I am working with a model saved using nn.parallel.DataParallel, hence the state_dict logic. Also, I am not using a model trained with norm_name=("instance", {"affine": True}), so it doesn't look like I can include that argument and load it.

from collections import OrderedDict

import torch
import torch_tensorrt
from monai.networks.nets import SwinUNETR

# n_classes, n_features, and model_path are defined elsewhere
model = SwinUNETR(
    img_size=(128, 128, 128),
    in_channels=1,
    out_channels=n_classes,
    feature_size=n_features,
    use_checkpoint=False,
)
load_w_data_parallel = False
if load_w_data_parallel:
    model = torch.nn.parallel.DataParallel(model)
    model.load_state_dict(torch.load(model_path))
else:
    # https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686
    # strip the "module." prefix that DataParallel adds to parameter names
    state_dict = torch.load(model_path)
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k.replace("module.", "")
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)

traced_model = torch.jit.trace(model, torch.rand(1, 1, 128, 128, 128))

trt_ts_model = torch_tensorrt.ts.compile(
    traced_model,
    inputs=[torch.rand(1, 1, 128, 128, 128)],
    enabled_precisions={torch.float, torch.half},
)

input_names = ['input']
output_names = ['output']
torch.onnx.export(
    model,
    torch.rand(1, 1, 128, 128, 128),
    'model.onnx',
    verbose=False,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=None,
    opset_version=11,
)

The torch_tensorrt.ts.compile command produces the following error:

*** RuntimeError: 0 INTERNAL ASSERT FAILED at "/opt/pytorch/pytorch/torch/csrc/jit/ir/alias_analysis.cpp":615, please report a bug to PyTorch. We don't have an op for aten::constant_pad_nd but it isn't a special case. Argument types: Tensor, int[], NoneType,

Candidates: aten::constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor aten::constant_pad_nd.out(Tensor self, SymInt[] pad, Scalar value=0, *, Tensor(a!) out) -> Tensor(a!)

The torch.onnx.export command yields another error:

*** torch.onnx.errors.SymbolicValueError: Failed to export a node '%6407 : Long(device=cpu) = onnx::Gather[axis=0](%6404, %6406), scope: monai.networks.nets.swin_unetr.SwinUNETR::/monai.networks.nets.swin_unetr.SwinTransformer::swinViT # /opt/monai/monai/networks/nets/swin_unetr.py:1024:0 ' (in list node %6444 : int[] = prim::ListConstruct(%6407), scope: monai.networks.nets.swin_unetr.SwinUNETR::/monai.networks.nets.swin_unetr.SwinTransformer::swinViT ) because it is not constant. Please try to make things (e.g. kernel sizes) static if possible. [Caused by the value '6444 defined in (%6444 : int[] = prim::ListConstruct(%6407), scope: monai.networks.nets.swin_unetr.SwinUNETR::/monai.networks.nets.swin_unetr.SwinTransformer::swinViT )' (type 'List[int]') in the TorchScript graph. The containing node has kind 'prim::ListConstruct'.]

Inputs:
    #0: 6407 defined in (%6407 : Long(device=cpu) = onnx::Gather[axis=0](%6404, %6406), scope: monai.networks.nets.swin_unetr.SwinUNETR::/monai.networks.nets.swin_unetr.SwinTransformer::swinViT # /opt/monai/monai/networks/nets/swin_unetr.py:1024:0
)  (type 'Tensor')
Outputs:
    #0: 6444 defined in (%6444 : int[] = prim::ListConstruct(%6407), scope: monai.networks.nets.swin_unetr.SwinUNETR::/monai.networks.nets.swin_unetr.SwinTransformer::swinViT
)  (type 'List[int]')