pytorch / executorch

On-device AI across mobile, embedded and edge for PyTorch
https://pytorch.org/executorch/

dl3 model fails to export with coreml delegates #5159

Open guangy10 opened 2 months ago

guangy10 commented 2 months ago

🐛 Describe the bug

dl3 is listed as supported by the Core ML delegate here: https://github.com/pytorch/executorch/tree/main/examples/apple/coreml#frequently-encountered-errors-and-resolution. However, export fails when running:

python3 -m examples.apple.coreml.scripts.export --model_name dl3

Converting PyTorch Frontend ==> MIL Ops: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▋| 374/375 [00:00<00:00, 3900.71 ops/s]
Running MIL frontend_pytorch pipeline: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 121.65 passes/s]
Running MIL default pipeline: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [00:02<00:00, 39.44 passes/s]
Running MIL backend_mlprogram pipeline: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 225.17 passes/s]
E0906 16:24:40.162000 73755 torch/export/_trace.py:999] See unsupported_operator in exportdb for unsupported case.                 https://pytorch.org/docs/main/generated/exportdb/index.html#unsupported-operator
Traceback (most recent call last):
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/Users/guangyang/executorch/examples/apple/coreml/scripts/export.py", line 184, in <module>
    exec_program = export_lowered_module_to_executorch_program(
  File "/Users/guangyang/executorch/examples/apple/coreml/scripts/export.py", line 106, in export_lowered_module_to_executorch_program
    export(lowered_module, example_inputs), compile_config=_EDGE_COMPILE_CONFIG
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/export/__init__.py", line 258, in export
    return _export(
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/export/_trace.py", line 1013, in wrapper
    raise e
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/export/_trace.py", line 986, in wrapper
    ep = fn(*args, **kwargs)
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/export/exported_program.py", line 97, in wrapper
    return fn(*args, **kwargs)
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/export/_trace.py", line 1921, in _export
    export_artifact = export_func(  # type: ignore[operator]
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/export/_trace.py", line 1220, in _strict_export
    return _strict_export_lower_to_aten_ir(
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/export/_trace.py", line 1248, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/export/_trace.py", line 556, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1432, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1244, in __call__
    return self._torchdynamo_orig_callable(
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 516, in __call__
    return _compile(
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 908, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 656, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 689, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 210, in _fn
    return fn(*args, **kwargs)
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 624, in transform
    tracer.run()
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2784, in run
    super().run()
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1668, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/_dynamo/variables/higher_order_ops.py", line 1247, in call_function
    return wrap_fx_proxy(
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1950, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 2245, in wrap_fx_proxy_cls
    unimplemented(
  File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 289, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor OrderedDict call_function executorch_call_delegate

from user code:
   File "/Users/guangyang/miniconda3/envs/executorch/lib/python3.10/site-packages/executorch/exir/lowered_backend_module.py", line 343, in forward
    return executorch_call_delegate(self, *args)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

Versions

PyTorch version: 2.5.0.dev20240829
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.6.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.29.0
Libc version: N/A

Python version: 3.10.13 (main, Sep 11 2023, 08:16:02) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.6.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU: Apple M1 Max

Versions of relevant libraries:
[pip3] executorch==0.4.0a0+52c9f30
[pip3] executorchcoreml==0.0.1
[pip3] flake8==6.0.0
[pip3] flake8-breakpoint==1.1.0
[pip3] flake8-bugbear==23.6.5
[pip3] flake8-comprehensions==3.12.0
[pip3] flake8-plugin-utils==1.3.3
[pip3] flake8-pyi==23.5.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.21.3
[pip3] pytorch-labs-segment-anything-fast==0.2
[pip3] torch==2.5.0.dev20240829
[pip3] torchaudio==2.5.0.dev20240829
[pip3] torchsr==1.0.4
[pip3] torchvision==0.20.0.dev20240829
[conda] executorch 0.4.0a0+52c9f30 pypi_0 pypi
[conda] executorchcoreml 0.0.1 pypi_0 pypi
[conda] numpy 1.21.3 pypi_0 pypi
[conda] pytorch-labs-segment-anything-fast 0.2 pypi_0 pypi
[conda] torch 2.5.0.dev20240829 pypi_0 pypi
[conda] torchaudio 2.5.0.dev20240829 pypi_0 pypi
[conda] torchfix 0.1.1 pypi_0 pypi
[conda] torchsr 1.0.4 pypi_0 pypi
[conda] torchvision 0.20.0.dev20240829 pypi_0 pypi

YifanShenSZ commented 1 month ago

Hi @guangy10, the error appears to come from torch:

torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor OrderedDict call_function executorch_call_delegate

That is, something went wrong in executorch_call_delegate.

I have confirmed that the Core ML model itself was generated correctly.
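The Dynamo error states that executorch_call_delegate returned an OrderedDict rather than a Tensor. Assuming the dl3 example is torchvision's DeepLabV3, whose forward returns an OrderedDict with an "out" entry, one possible workaround is to wrap the model so it returns only the "out" tensor before lowering. This is a minimal sketch under that assumption: DictOutputModel and TensorOutputWrapper are hypothetical names, and the 1x1 convolution is only a stand-in for the real network.

```python
from collections import OrderedDict

import torch
from torch import nn


class DictOutputModel(nn.Module):
    """Hypothetical stand-in for a segmentation model whose forward returns
    OrderedDict({"out": ...}), as torchvision's DeepLabV3 does (assumption)."""

    def __init__(self) -> None:
        super().__init__()
        # Placeholder layer; the real model is a full DeepLabV3 network.
        self.conv = nn.Conv2d(3, 21, kernel_size=1)

    def forward(self, x: torch.Tensor) -> "OrderedDict[str, torch.Tensor]":
        result = OrderedDict()
        result["out"] = self.conv(x)
        return result


class TensorOutputWrapper(nn.Module):
    """Hypothetical wrapper: unwraps the "out" entry so the module returns a
    plain Tensor instead of an OrderedDict."""

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)["out"]


if __name__ == "__main__":
    model = TensorOutputWrapper(DictOutputModel()).eval()
    example_inputs = (torch.randn(1, 3, 224, 224),)
    # torch.export traces the wrapper; its single output is a Tensor.
    ep = torch.export.export(model, example_inputs)
    print(type(ep.module()(*example_inputs)))
```

This only changes the model's output type. Whether it resolves the Core ML export failure is an assumption that would need to be checked by rerunning the export script with the wrapped model.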