apple / coremltools

Core ML tools contain supporting tools for Core ML model conversion, editing, and validation.
https://coremltools.readme.io
BSD 3-Clause "New" or "Revised" License

Error when converting a PyTorch traced model to "neuralnetwork" or "mlprogram" format #2112

Open andreascuderi opened 10 months ago

andreascuderi commented 10 months ago

🐞 Describing the bug

I'm trying to create a Core ML package from a traced model, but I get the following error when calling coremltools.convert:

ValueError: Op "135" (op_type: slice_by_index) Input x="130" expects tensor or scalar of dtype from type domain ['fp16', 'fp32', 'int32', 'bool'] but got tensor[1,2,2049,440,complex64]

Stack Trace

Converting PyTorch Frontend ==> MIL Ops: 5%| | 91/1722 [00:00<00:00, 3237.72 o
Traceback (most recent call last):
  File "/Users/andreascuderi/nTrack/trunk/n-Track_EX/Scripts/./hybrid_demucs_tracer.py", line 72, in <module>
    mlpackage_obj = ct.convert(
  File "/usr/local/lib/python3.11/site-packages/coremltools/converters/_converters_entry.py", line 574, in convert
    mlmodel = mil_convert(
  File "/usr/local/lib/python3.11/site-packages/coremltools/converters/mil/converter.py", line 188, in mil_convert
    return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
  File "/usr/local/lib/python3.11/site-packages/coremltools/converters/mil/converter.py", line 212, in _mil_convert
    proto, mil_program = mil_convert_to_proto(
  File "/usr/local/lib/python3.11/site-packages/coremltools/converters/mil/converter.py", line 286, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
  File "/usr/local/lib/python3.11/site-packages/coremltools/converters/mil/converter.py", line 108, in __call__
    return load(*args, **kwargs)
  File "/usr/local/lib/python3.11/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 80, in load
    return _perform_torch_convert(converter, debug)
  File "/usr/local/lib/python3.11/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 99, in _perform_torch_convert
    prog = converter.convert()
  File "/usr/local/lib/python3.11/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 519, in convert
    convert_nodes(self.context, self.graph)
  File "/usr/local/lib/python3.11/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 88, in convert_nodes
    add_op(context, node)
  File "/usr/local/lib/python3.11/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 4209, in _slice
    res = mb.slice_by_index(**kwargs)
  File "/usr/local/lib/python3.11/site-packages/coremltools/converters/mil/mil/ops/registry.py", line 182, in add_op
    return cls._add_op(op_cls_to_add, **kwargs)
  File "/usr/local/lib/python3.11/site-packages/coremltools/converters/mil/mil/builder.py", line 168, in _add_op
    new_op = op_cls(**kwargs)
  File "/usr/local/lib/python3.11/site-packages/coremltools/converters/mil/mil/operation.py", line 190, in __init__
    self._validate_and_set_inputs(input_kv)
  File "/usr/local/lib/python3.11/site-packages/coremltools/converters/mil/mil/operation.py", line 503, in _validate_and_set_inputs
    self.input_spec.validate_inputs(self.name, self.op_type, input_kvs)
  File "/usr/local/lib/python3.11/site-packages/coremltools/converters/mil/mil/input_type.py", line 163, in validate_inputs
    raise ValueError(msg.format(name, var.name, input_type.type_str,
ValueError: Op "138" (op_type: slice_by_index) Input x="133" expects tensor or scalar of dtype from type domain ['fp16', 'fp32', 'int32', 'bool'] but got tensor[1,2,2049,44,complex64]

To Reproduce

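import torch
import coremltools as ct

# `model` is the PyTorch module being converted (its definition is not shown in this issue)
input = torch.rand(1, 2, 40000)
traced_module = torch.jit.trace(model, input)
ct.convert(traced_module, convert_to="neuralnetwork", inputs=[ct.TensorType(shape=input.shape)])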

TobyRoseman commented 10 months ago

It looks like your PyTorch model has an in-place slice operation on a tensor of complex numbers. Given that the Core ML framework does not support complex numbers, I don't think coremltools will be able to convert this model.

I encourage you to use the Feedback Assistant tool to submit a feature request for the Core ML framework to support complex numbers.

andreascuderi commented 10 months ago

@TobyRoseman thanks for the answer. According to #1539, coremltools has supported complex-number ops since version 6.2. Does the problem depend strictly on the slice op?

YifanShenSZ commented 10 months ago

@junpeiz We managed to support some complex algebra by decomposing the real and the imaginary parts. Is it possible to do the same for slice_by_index?

I guess not, since

  1. Eventually, we need the model to have real output?
  2. slice_by_index would give complex output, violating 1
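
The real/imaginary decomposition mentioned above can be illustrated with complex multiplication written purely in real arithmetic: each complex value travels as a (real, imaginary) pair of real numbers. This is a minimal sketch of the idea, not coremltools' actual lowering; the function name `complex_mul` is hypothetical.

```python
from typing import Tuple


def complex_mul(a_re: float, a_im: float, b_re: float, b_im: float) -> Tuple[float, float]:
    """Multiply (a_re + i*a_im) by (b_re + i*b_im) using only real arithmetic.

    Hypothetical helper: it mirrors the idea of carrying a complex value as a
    (real, imaginary) pair of real numbers, not coremltools' internal code.
    """
    # (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    out_re = a_re * b_re - a_im * b_im
    out_im = a_re * b_im + a_im * b_re
    return out_re, out_im


# (1 + 2i) * (3 + 4i) = -5 + 10i
print(complex_mul(1.0, 2.0, 3.0, 4.0))  # (-5.0, 10.0)
```

Every intermediate here is a real number, which is why such ops fit a backend without a complex dtype.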

junpeiz commented 10 months ago

It's possible to do the same for slice_by_index, where the real and imaginary parts are sliced individually. The real-output requirement applies to the whole model, which means that as long as the complex output of slice_by_index is consumed by subsequent ops that produce a real output, it should be fine.
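To make the per-part slicing concrete, here is a minimal pure-Python sketch, assuming a complex tensor is carried as two real lists with identical slice indices applied to each. The helper names (`split_complex`, `slice_complex`, `magnitude`) are hypothetical, and 1-D lists stand in for real tensors; this is not coremltools' implementation. The slice yields a complex intermediate, but the final `magnitude` step returns real values, matching the whole-model real-output requirement.

```python
from typing import List, Tuple

# Hypothetical representation: a complex "tensor" as (real part, imaginary part)
ComplexPair = Tuple[List[float], List[float]]


def split_complex(values: List[complex]) -> ComplexPair:
    """Decompose complex values into separate real and imaginary lists."""
    return [v.real for v in values], [v.imag for v in values]


def slice_complex(pair: ComplexPair, begin: int, end: int, stride: int = 1) -> ComplexPair:
    """Slice the real and imaginary parts independently with the same indices."""
    real, imag = pair
    return real[begin:end:stride], imag[begin:end:stride]


def magnitude(pair: ComplexPair) -> List[float]:
    """A downstream op that turns the complex intermediate into real output."""
    real, imag = pair
    return [(r * r + i * i) ** 0.5 for r, i in zip(real, imag)]


values = [3 + 4j, 1 + 0j, 0 + 2j, 6 + 8j]
real_part, imag_part = slice_complex(split_complex(values), 1, 4, 2)
print(real_part, imag_part)  # [1.0, 6.0] [0.0, 8.0]
print(magnitude((real_part, imag_part)))  # [1.0, 10.0]
```

Slicing the two parts with identical indices gives the same elements as slicing the complex values directly, so no complex dtype ever has to reach the backend.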