pytorch / executorch

On-device AI across mobile, embedded and edge for PyTorch
https://pytorch.org/executorch/

GroupNorm fails export to Core ML #6817


jakmro commented 2 days ago

🐛 Describe the bug

Description

I tried to export a PyTorch model using torch.nn.GroupNorm to Core ML with ExecuTorch, but encountered a ValueError: Unsupported fx node aten_native_group_norm_default, kind native_group_norm.

Traceback

Traceback (most recent call last):
  File "/Users/jakubmroz/miniconda3/envs/executorchissue/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/Users/jakubmroz/miniconda3/envs/executorchissue/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/Users/jakubmroz/Desktop/issue/executorch/examples/apple/coreml/scripts/export.py", line 205, in <module>
    main()
  File "/Users/jakubmroz/Desktop/issue/executorch/examples/apple/coreml/scripts/export.py", line 184, in main
    lowered_module, edge_copy = lower_module_to_coreml(
  File "/Users/jakubmroz/Desktop/issue/executorch/examples/apple/coreml/scripts/export.py", line 98, in lower_module_to_coreml
    lowered_module = to_backend(
  File "/Users/jakubmroz/miniconda3/envs/executorchissue/lib/python3.10/functools.py", line 878, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/Users/jakubmroz/miniconda3/envs/executorchissue/lib/python3.10/site-packages/executorch/exir/backend/backend_api.py", line 113, in _
    preprocess_result: PreprocessResult = cls.preprocess(
  File "/Users/jakubmroz/miniconda3/envs/executorchissue/lib/python3.10/site-packages/executorch/backends/apple/coreml/compiler/coreml_preprocess.py", line 431, in preprocess
    mlmodel = ct.convert(
  File "/Users/jakubmroz/miniconda3/envs/executorchissue/lib/python3.10/site-packages/coremltools/converters/_converters_entry.py", line 635, in convert
    mlmodel = mil_convert(
  File "/Users/jakubmroz/miniconda3/envs/executorchissue/lib/python3.10/site-packages/coremltools/converters/mil/converter.py", line 188, in mil_convert
    return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
  File "/Users/jakubmroz/miniconda3/envs/executorchissue/lib/python3.10/site-packages/coremltools/converters/mil/converter.py", line 212, in _mil_convert
    proto, mil_program = mil_convert_to_proto(
  File "/Users/jakubmroz/miniconda3/envs/executorchissue/lib/python3.10/site-packages/coremltools/converters/mil/converter.py", line 288, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
  File "/Users/jakubmroz/miniconda3/envs/executorchissue/lib/python3.10/site-packages/coremltools/converters/mil/converter.py", line 108, in __call__
    return load(*args, **kwargs)
  File "/Users/jakubmroz/miniconda3/envs/executorchissue/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 81, in load
    converter = TorchConverter(
  File "/Users/jakubmroz/miniconda3/envs/executorchissue/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 618, in __init__
    self.graph = InternalTorchIRGraph.from_exir(exir=loaded_model)
  File "/Users/jakubmroz/miniconda3/envs/executorchissue/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/internal_graph.py", line 574, in from_exir
    nodes.append(InternalTorchIRNode.from_exir_node(node=node))
  File "/Users/jakubmroz/miniconda3/envs/executorchissue/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/internal_graph.py", line 285, in from_exir_node
    raise ValueError(f"Unsupported fx node {str(node)}, kind {kind}")
ValueError: Unsupported fx node aten_native_group_norm_default, kind native_group_norm

Minimal Example

To reproduce this error, modify the following lines in examples/apple/coreml/scripts/export.py.

Replace:

    model, example_inputs, _, _ = EagerModelFactory.create_model(
        *MODEL_NAME_TO_MODEL[args.model_name]
    )

With:

    model = torch.nn.GroupNorm(3, 6)
    example_inputs = (torch.randn(20, 6, 10, 10),)

Then execute the script using:

python3 -m examples.apple.coreml.scripts.export --model_name mv3

Workaround

As a temporary workaround, I replaced torch.nn.GroupNorm with a custom implementation that exports successfully.

    class CustomGroupNorm(torch.nn.Module):
        def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
            """Custom Group Normalization."""
            super().__init__()
            self.num_groups = num_groups
            self.num_channels = num_channels
            self.eps = eps
            self.affine = affine

            if self.affine:
                self.weight = torch.nn.Parameter(torch.ones(num_channels))
                self.bias = torch.nn.Parameter(torch.zeros(num_channels))

        def forward(self, x):
            N, C, *rest = x.shape
            G = self.num_groups
            # Split channels into groups: (N, G, C // G, *spatial)
            x = x.view(N, G, C // G, *rest)
            # Normalize over each group's channel and spatial dimensions.
            dims = (2, *range(3, x.dim()))
            mean = x.mean(dim=dims, keepdim=True)
            var = ((x - mean) ** 2).mean(dim=dims, keepdim=True)
            x = (x - mean) / (var + self.eps).sqrt()
            x = x.view(N, C, *rest)
            if self.affine:
                shape = (1, -1, *([1] * len(rest)))
                x = x * self.weight.view(shape) + self.bias.view(shape)
            return x

Then, in export.py, use it in place of the factory call:

    model = CustomGroupNorm(3, 6)
    example_inputs = (torch.randn(20, 6, 10, 10),)
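The per-group statistics that the workaround computes can be sanity-checked in plain Python on a tiny, made-up example; the group_norm helper below is illustrative only (it mirrors the same math without torch) and is not part of the repro:

```python
import math

def group_norm(x, num_groups, eps=1e-5):
    """Group-normalize one sample given as a list of channels,
    each channel a flat list of spatial values (no affine step)."""
    C = len(x)
    per_group = C // num_groups
    out = []
    for g in range(num_groups):
        group = x[g * per_group:(g + 1) * per_group]
        # Pool all values in this group's channels.
        vals = [v for channel in group for v in channel]
        mean = sum(vals) / len(vals)
        var = sum((v - mean) ** 2 for v in vals) / len(vals)
        inv = 1.0 / math.sqrt(var + eps)
        for channel in group:
            out.append([(v - mean) * inv for v in channel])
    return out

# Hypothetical input: 4 channels, 2 spatial values each, 2 groups.
x = [[1.0, 2.0], [3.0, 4.0], [10.0, 20.0], [30.0, 40.0]]
y = group_norm(x, num_groups=2)
```

After normalization, each group's values have mean approximately 0 and variance approximately 1 (exactly var / (var + eps), so slightly below 1).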

Versions

Collecting environment information...
PyTorch version: 2.6.0.dev20241101
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.5 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.3)
CMake version: version 3.31.0
Libc version: N/A

Python version: 3.10.0 (default, Mar 3 2022, 03:54:28) [Clang 12.0.0 ] (64-bit runtime)
Python platform: macOS-14.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU: Apple M1 Pro

Versions of relevant libraries:
[pip3] executorch==0.5.0a0+dc41596
[pip3] executorchcoreml==0.0.1
[pip3] numpy==1.26.4
[pip3] torch==2.6.0.dev20241101
[pip3] torchaudio==2.5.0.dev20241101
[pip3] torchsr==1.0.4
[pip3] torchvision==0.20.0.dev20241101
[conda] executorch 0.5.0a0+dc41596 pypi_0 pypi
[conda] executorchcoreml 0.0.1 pypi_0 pypi
[conda] numpy 1.26.4 pypi_0 pypi
[conda] torch 2.6.0.dev20241101 pypi_0 pypi
[conda] torchaudio 2.5.0.dev20241101 pypi_0 pypi
[conda] torchsr 1.0.4 pypi_0 pypi
[conda] torchvision 0.20.0.dev20241101 pypi_0 pypi

GregoryComer commented 1 day ago

@cccclai Do you have any suggestions here? It looks like it might be a bug we need to report to the Core ML team.