Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

Cannot export SwinUNETR model as TorchScript #6823

Closed haoqiwang1 closed 10 months ago

haoqiwang1 commented 1 year ago

Describe the bug

I tried to train a model with SwinUNETR and export it as TorchScript for later use in MONAI Deploy packaging, but the export fails with both torch.jit.trace and torch.jit.script.

To Reproduce

from monai.networks.nets import SwinUNETR
import torch
device = torch.device("cpu")

model = SwinUNETR(
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=14,
    feature_size=48,
    use_checkpoint=True,
).to(device)
  1. Try to use trace
    traced = torch.jit.trace(model, torch.rand(1, 96, 96, 96))
/opt/conda/lib/python3.10/site-packages/monai/networks/blocks/patchembedding.py:186: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if w % self.patch_size[1] != 0:
/opt/conda/lib/python3.10/site-packages/monai/networks/blocks/patchembedding.py:188: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if h % self.patch_size[0] != 0:
/opt/conda/lib/python3.10/site-packages/monai/networks/nets/swin_unetr.py:397: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if x_size[i] <= window_size[i]:
/opt/conda/lib/python3.10/site-packages/monai/networks/nets/swin_unetr.py:889: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  hp = int(np.ceil(h / window_size[0])) * window_size[0]
/opt/conda/lib/python3.10/site-packages/monai/networks/nets/swin_unetr.py:890: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  wp = int(np.ceil(w / window_size[1])) * window_size[1]
/opt/conda/lib/python3.10/site-packages/monai/networks/nets/swin_unetr.py:627: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if pad_r > 0 or pad_b > 0:
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[3], line 1
----> 1 traced = torch.jit.trace(model, torch.rand(1, 96, 96, 96))

File /opt/conda/lib/python3.10/site-packages/torch/jit/_trace.py:759, in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
    756     return func
    758 if isinstance(func, torch.nn.Module):
--> 759     return trace_module(
    760         func,
    761         {"forward": example_inputs},
    762         None,
    763         check_trace,
    764         wrap_check_inputs(check_inputs),
    765         check_tolerance,
    766         strict,
    767         _force_outplace,
    768         _module_class,
    769     )
    771 if (
    772     hasattr(func, "__self__")
    773     and isinstance(func.__self__, torch.nn.Module)
    774     and func.__name__ == "forward"
    775 ):
    776     return trace_module(
    777         func.__self__,
    778         {"forward": example_inputs},
   (...)
    785         _module_class,
    786     )

File /opt/conda/lib/python3.10/site-packages/torch/jit/_trace.py:976, in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
    972     argument_names = get_callable_argument_names(func)
    974 example_inputs = make_tuple(example_inputs)
--> 976 module._c._create_method_from_trace(
    977     method_name,
    978     func,
    979     example_inputs,
    980     var_lookup_fn,
    981     strict,
    982     _force_outplace,
    983     argument_names,
    984 )
    985 check_trace_method = module._c._get_method(method_name)
    987 # Check the trace against new traces created from user-specified inputs

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1182, in Module._slow_forward(self, *input, **kwargs)
   1180         recording_scopes = False
   1181 try:
-> 1182     result = self.forward(*input, **kwargs)
   1183 finally:
   1184     if recording_scopes:

File /opt/conda/lib/python3.10/site-packages/monai/networks/nets/swin_unetr.py:301, in SwinUNETR.forward(self, x_in)
    300 def forward(self, x_in):
--> 301     hidden_states_out = self.swinViT(x_in, self.normalize)
    302     enc0 = self.encoder1(x_in)
    303     enc1 = self.encoder2(hidden_states_out[0])

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1182, in Module._slow_forward(self, *input, **kwargs)
   1180         recording_scopes = False
   1181 try:
-> 1182     result = self.forward(*input, **kwargs)
   1183 finally:
   1184     if recording_scopes:

File /opt/conda/lib/python3.10/site-packages/monai/networks/nets/swin_unetr.py:1043, in SwinTransformer.forward(self, x, normalize)
   1041 if self.use_v2:
   1042     x0 = self.layers1c[0](x0.contiguous())
-> 1043 x1 = self.layers1[0](x0.contiguous())
   1044 x1_out = self.proj_out(x1, normalize)
   1045 if self.use_v2:

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1182, in Module._slow_forward(self, *input, **kwargs)
   1180         recording_scopes = False
   1181 try:
-> 1182     result = self.forward(*input, **kwargs)
   1183 finally:
   1184     if recording_scopes:

File /opt/conda/lib/python3.10/site-packages/monai/networks/nets/swin_unetr.py:893, in BasicLayer.forward(self, x)
    891 attn_mask = compute_mask([hp, wp], window_size, shift_size, x.device)
    892 for blk in self.blocks:
--> 893     x = blk(x, attn_mask)
    894 x = x.view(b, h, w, -1)
    895 if self.downsample is not None:

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1182, in Module._slow_forward(self, *input, **kwargs)
   1180         recording_scopes = False
   1181 try:
-> 1182     result = self.forward(*input, **kwargs)
   1183 finally:
   1184     if recording_scopes:

File /opt/conda/lib/python3.10/site-packages/monai/networks/nets/swin_unetr.py:672, in SwinTransformerBlock.forward(self, x, mask_matrix)
    670 shortcut = x
    671 if self.use_checkpoint:
--> 672     x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
    673 else:
    674     x = self.forward_part1(x, mask_matrix)

File /opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py:249, in checkpoint(function, use_reentrant, *args, **kwargs)
    246     raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
    248 if use_reentrant:
--> 249     return CheckpointFunction.apply(function, preserve, *args)
    250 else:
    251     return _checkpoint_without_reentrant(
    252         function,
    253         preserve,
    254         *args,
    255         **kwargs,
    256     )

RuntimeError: _Map_base::at
  2. Try to use script
    model_scripted = torch.jit.script(model)
---------------------------------------------------------------------------
NotSupportedError                         Traceback (most recent call last)
Cell In[4], line 1
----> 1 model_scripted = torch.jit.script(model) # Export to TorchScript

File /opt/conda/lib/python3.10/site-packages/torch/jit/_script.py:1286, in script(obj, optimize, _frames_up, _rcb, example_inputs)
   1284 if isinstance(obj, torch.nn.Module):
   1285     obj = call_prepare_scriptable_func(obj)
-> 1286     return torch.jit._recursive.create_script_module(
   1287         obj, torch.jit._recursive.infer_methods_to_compile
   1288     )
   1290 if isinstance(obj, dict):
   1291     return create_script_dict(obj)

File /opt/conda/lib/python3.10/site-packages/torch/jit/_recursive.py:476, in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
    474 if not is_tracing:
    475     AttributeTypeIsSupportedChecker().check(nn_module)
--> 476 return create_script_module_impl(nn_module, concrete_type, stubs_fn)

File /opt/conda/lib/python3.10/site-packages/torch/jit/_recursive.py:538, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    535     script_module._concrete_type = concrete_type
    537 # Actually create the ScriptModule, initializing it with the function we just defined
--> 538 script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
    540 # Compile methods if necessary
    541 if concrete_type not in concrete_type_store.methods_compiled:

File /opt/conda/lib/python3.10/site-packages/torch/jit/_script.py:615, in RecursiveScriptModule._construct(cpp_module, init_fn)
    602 """
    603 Construct a RecursiveScriptModule that's ready for use. PyTorch
    604 code should use this to construct a RecursiveScriptModule instead
   (...)
    612     init_fn:  Lambda that initializes the RecursiveScriptModule passed to it.
    613 """
    614 script_module = RecursiveScriptModule(cpp_module)
--> 615 init_fn(script_module)
    617 # Finalize the ScriptModule: replace the nn.Module state with our
    618 # custom implementations and flip the _initializing bit.
    619 RecursiveScriptModule._finalize_scriptmodule(script_module)

File /opt/conda/lib/python3.10/site-packages/torch/jit/_recursive.py:516, in create_script_module_impl.<locals>.init_fn(script_module)
    513     scripted = orig_value
    514 else:
    515     # always reuse the provided stubs_fn to infer the methods to compile
--> 516     scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
    518 cpp_module.setattr(name, scripted)
    519 script_module._modules[name] = scripted

File /opt/conda/lib/python3.10/site-packages/torch/jit/_recursive.py:538, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    535     script_module._concrete_type = concrete_type
    537 # Actually create the ScriptModule, initializing it with the function we just defined
--> 538 script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
    540 # Compile methods if necessary
    541 if concrete_type not in concrete_type_store.methods_compiled:

File /opt/conda/lib/python3.10/site-packages/torch/jit/_script.py:615, in RecursiveScriptModule._construct(cpp_module, init_fn)
    602 """
    603 Construct a RecursiveScriptModule that's ready for use. PyTorch
    604 code should use this to construct a RecursiveScriptModule instead
   (...)
    612     init_fn:  Lambda that initializes the RecursiveScriptModule passed to it.
    613 """
    614 script_module = RecursiveScriptModule(cpp_module)
--> 615 init_fn(script_module)
    617 # Finalize the ScriptModule: replace the nn.Module state with our
    618 # custom implementations and flip the _initializing bit.
    619 RecursiveScriptModule._finalize_scriptmodule(script_module)

File /opt/conda/lib/python3.10/site-packages/torch/jit/_recursive.py:516, in create_script_module_impl.<locals>.init_fn(script_module)
    513     scripted = orig_value
    514 else:
    515     # always reuse the provided stubs_fn to infer the methods to compile
--> 516     scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
    518 cpp_module.setattr(name, scripted)
    519 script_module._modules[name] = scripted

    [... skipping similar frames: RecursiveScriptModule._construct at line 615 (2 times), create_script_module_impl at line 538 (2 times), create_script_module_impl.<locals>.init_fn at line 516 (2 times)]

File /opt/conda/lib/python3.10/site-packages/torch/jit/_recursive.py:538, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    535     script_module._concrete_type = concrete_type
    537 # Actually create the ScriptModule, initializing it with the function we just defined
--> 538 script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
    540 # Compile methods if necessary
    541 if concrete_type not in concrete_type_store.methods_compiled:

File /opt/conda/lib/python3.10/site-packages/torch/jit/_script.py:615, in RecursiveScriptModule._construct(cpp_module, init_fn)
    602 """
    603 Construct a RecursiveScriptModule that's ready for use. PyTorch
    604 code should use this to construct a RecursiveScriptModule instead
   (...)
    612     init_fn:  Lambda that initializes the RecursiveScriptModule passed to it.
    613 """
    614 script_module = RecursiveScriptModule(cpp_module)
--> 615 init_fn(script_module)
    617 # Finalize the ScriptModule: replace the nn.Module state with our
    618 # custom implementations and flip the _initializing bit.
    619 RecursiveScriptModule._finalize_scriptmodule(script_module)

File /opt/conda/lib/python3.10/site-packages/torch/jit/_recursive.py:516, in create_script_module_impl.<locals>.init_fn(script_module)
    513     scripted = orig_value
    514 else:
    515     # always reuse the provided stubs_fn to infer the methods to compile
--> 516     scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
    518 cpp_module.setattr(name, scripted)
    519 script_module._modules[name] = scripted

File /opt/conda/lib/python3.10/site-packages/torch/jit/_recursive.py:542, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    540 # Compile methods if necessary
    541 if concrete_type not in concrete_type_store.methods_compiled:
--> 542     create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    543     # Create hooks after methods to ensure no name collisions between hooks and methods.
    544     # If done before, hooks can overshadow methods that aren't exported.
    545     create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs)

File /opt/conda/lib/python3.10/site-packages/torch/jit/_recursive.py:393, in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    390 property_defs = [p.def_ for p in property_stubs]
    391 property_rcbs = [p.resolution_callback for p in property_stubs]
--> 393 concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)

File /opt/conda/lib/python3.10/site-packages/torch/jit/_recursive.py:863, in try_compile_fn(fn, loc)
    859 # We don't have the actual scope where the function was defined, but we can
    860 # extract the necessary info from the closed over variables on the function
    861 # object
    862 rcb = _jit_internal.createResolutionCallbackFromClosure(fn)
--> 863 return torch.jit.script(fn, _rcb=rcb)

File /opt/conda/lib/python3.10/site-packages/torch/jit/_script.py:1340, in script(obj, optimize, _frames_up, _rcb, example_inputs)
   1338 if maybe_already_compiled_fn:
   1339     return maybe_already_compiled_fn
-> 1340 ast = get_jit_def(obj, obj.__name__)
   1341 if _rcb is None:
   1342     _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)

File /opt/conda/lib/python3.10/site-packages/torch/jit/frontend.py:293, in get_jit_def(fn, def_name, self_name, is_classmethod)
    290     qualname = get_qualified_name(fn)
    291     pdt_arg_types = type_trace_db.get_args_types(qualname)
--> 293 return build_def(parsed_def.ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)

File /opt/conda/lib/python3.10/site-packages/torch/jit/frontend.py:331, in build_def(ctx, py_def, type_line, def_name, self_name, pdt_arg_types)
    326 body = py_def.body
    327 r = ctx.make_range(py_def.lineno + len(py_def.decorator_list),
    328                    py_def.col_offset,
    329                    py_def.col_offset + len("def"))
--> 331 param_list = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)
    332 return_type = None
    333 if getattr(py_def, 'returns', None) is not None:

File /opt/conda/lib/python3.10/site-packages/torch/jit/frontend.py:355, in build_param_list(ctx, py_args, self_name, pdt_arg_types)
    353     expr = py_args.kwarg
    354     ctx_range = ctx.make_range(expr.lineno, expr.col_offset - 1, expr.col_offset + len(expr.arg))
--> 355     raise NotSupportedError(ctx_range, _vararg_kwarg_err)
    356 if py_args.vararg is not None:
    357     expr = py_args.vararg

NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 164
def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
                                                             ~~~~~~~ <--- HERE
    r"""Checkpoint a model or part of the model
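
The NotSupportedError above is raised because the TorchScript frontend refuses to compile a function whose signature contains `*args` or `**kwargs` (the `build_param_list` frames check `py_args.kwarg` and `py_args.vararg`). A rough, standard-library-only way to spot such functions before calling `torch.jit.script` might look like the sketch below. The names `has_variadic_params` and `checkpoint_like` are hypothetical, and the helper only approximates the frontend's check: it does not look at keyword-only arguments with defaults, which the error message also mentions.

```python
import inspect


def has_variadic_params(fn) -> bool:
    """Return True if fn declares *args or **kwargs in its signature."""
    kinds = {p.kind for p in inspect.signature(fn).parameters.values()}
    variadic = {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
    return bool(kinds & variadic)


# Hypothetical stand-in with the same signature shape as the
# checkpoint() definition quoted in the error message.
def checkpoint_like(function, *args, use_reentrant: bool = True, **kwargs):
    return function(*args, **kwargs)


# A plain fixed-arity function, for comparison.
def fixed_arity(x, y):
    return x + y


print(has_variadic_params(checkpoint_like))
print(has_variadic_params(fixed_arity))
```

A check like this only flags the signature; it says nothing about whether the function body is otherwise scriptable.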

Environment

Ensuring you use the relevant python executable, please paste the output of:

python -c 'import monai; monai.config.print_debug_info()'
================================
Printing MONAI config...
================================
MONAI version: 1.3.dev2330
Numpy version: 1.23.5
Pytorch version: 1.13.1+cu117
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: ce84232bbde7d4c31b274f61c8ccef539c421aaa
MONAI __file__: /opt/conda/lib/python3.10/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.1.0
scikit-image version: 0.21.0
Pillow version: 9.5.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: 4.7.1
TorchVision version: 0.14.1+cu117
tqdm version: 4.65.0
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.3
pandas version: 1.5.3
einops version: 0.6.1
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

================================
Printing system config...
================================
System: Linux
Linux version: Debian GNU/Linux 11 (bullseye)
Platform: Linux-5.10.0-23-cloud-amd64-x86_64-with-glibc2.31
Processor: 
Machine: x86_64
Python version: 3.10.11
Process name: python
Command: ['python', '-c', 'import monai; monai.config.print_debug_info()']
Open files: []
Num physical CPUs: 2
Num logical CPUs: 4
Num usable CPUs: 4
CPU usage (%): [6.2, 6.2, 4.8, 100.0]
CPU freq. (MHz): 2300
Load avg. in last 1, 5, 15 mins (%): [14.2, 5.5, 2.2]
Disk usage (%): 11.9
Avg. sensor temp. (Celsius): UNKNOWN for given OS
Total physical memory (GB): 14.6
Available memory (GB): 12.8
Used memory (GB): 1.6

================================
Printing GPU config...
================================
Num GPUs: 1
Has CUDA: True
CUDA version: 11.7
cuDNN enabled: True
cuDNN version: 8500
Current device: 0
Library compiled for CUDA architectures: ['sm_37', 'sm_50', 'sm_60', 'sm_70', 'sm_75', 'sm_80', 'sm_86']
GPU 0 Name: Tesla T4
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 40
GPU 0 Total memory (GB): 14.6
GPU 0 CUDA capability (maj.min): 7.5

Thanks for your help!

KumoLiu commented 1 year ago

Hi @haoqiwang1, this seems to be a known issue. You can set use_checkpoint=False so that the model can be converted to TorchScript with torch.jit.trace. See https://github.com/Project-MONAI/model-zoo/pull/316

Thanks!

vikashg commented 10 months ago

Closing due to inactivity and because alternatives exist.