apple / coremltools

Core ML tools contain supporting tools for Core ML model conversion, editing, and validation.
https://coremltools.readme.io
BSD 3-Clause "New" or "Revised" License
4.41k stars 636 forks source link

Exported model assumes that the input should always be similar to the tracing example #1991

Open hadiidbouk opened 1 year ago

hadiidbouk commented 1 year ago

🐞Describing the bug

The bug isn't detected while exporting the model, no error is shown, however, when I try using the model in Swift I got this error:

Thread 17: Fatal error: 'try!' expression unexpectedly raised an error: Error Domain=com.apple.CoreML Code=0 "MultiArray shape (1 x 27200) does not match the shape (1 x 16000) specified in the model description" UserInfo={NSLocalizedDescription=MultiArray shape (1 x 27200) does not match the shape (1 x 16000) specified in the model description}

On this line:

let output = try! self.inferenceModule.prediction(input: input)

There is a problem in exporting somehow that makes the tracing not work as expected, it keeps assuming that my input is always the same as the one passed to the trace function.

The first thing to think of here is that the tracing is failing, but that's not the case because I am able to export the model using Pytorch lighting and use it with the LibTorch C++ library without any problem.

Stack Trace

When both 'convert_to' and 'minimum_deployment_target' not specified, 'convert_to' is set to "mlprogram" and 'minimum_deployment_targer' is set to ct.target.iOS15 (which is same as ct.target.macOS12). Note: the model will not run on systems older than iOS15/macOS12/watchOS8/tvOS15. In order to make your model run on older system, please set the 'minimum_deployment_target' to iOS14/iOS13. Details please see the link: https://coremltools.readme.io/docs/unified-conversion-api#target-conversion-formats
Support for converting Torch Script Models is experimental. If possible you should use a traced model for conversion.
Tuple detected at graph output. This will be flattened in the converted model.
Converting PyTorch Frontend ==> MIL Ops:   0%|                                                                                                      | 0/486 [00:00<?, ? ops/s]Saving value type of int64 into a builtin type of int32, might lose precision!
Saving value type of int64 into a builtin type of int32, might lose precision!
Saving value type of int64 into a builtin type of int32, might lose precision!
Converting PyTorch Frontend ==> MIL Ops:  71%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰                          | 345/486 [00:00<00:00, 3449.42 ops/s]Saving value type of int64 into a builtin type of int32, might lose precision!
Saving value type of int64 into a builtin type of int32, might lose precision!
Converting PyTorch Frontend ==> MIL Ops: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹| 484/486 [00:00<00:00, 3123.51 ops/s]
Running MIL frontend_pytorch pipeline:   0%|                                                                                                       | 0/5 [00:00<?, ? passes/s]Saving value type of int64 into a builtin type of int32, might lose precision!
Saving value type of int64 into a builtin type of int32, might lose precision!
Running MIL frontend_pytorch pipeline: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 5/5 [00:00<00:00, 142.70 passes/s]
Running MIL default pipeline:   0%|                                                                                                               | 0/66 [00:00<?, ? passes/s]Saving value type of float64 into a builtin type of fp32, might lose precision!
Saving value type of float64 into a builtin type of fp32, might lose precision!
Running MIL default pipeline:   6%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–                                                                                                | 4/66 [00:00<00:01, 39.63 passes/s] /python3.9/site-packages/coremltools/converters/mil/mil/passes/defs/preprocess.py:267: UserWarning: Output, 'input57.1', of the source model, has been renamed to 'input57_1' in the Core ML model.
      warnings.warn(msg.format(var.name, new_name))

Running MIL default pipeline: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 66/66 [00:03<00:00, 21.46 passes/s]
Running MIL backend_mlprogram pipeline: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 12/12 [00:00<00:00, 393.06 passes/s]
custom_model = MyCustomModel()
custom_model.eval()

audio_signal = torch.randn(1, 16000)
audio_signal_len = torch.tensor([audio_signal.shape[1]])

scripted_model = torch.jit.trace(
    custom_model.forward, example_inputs=(audio_signal, audio_signal_len)
)

os.remove(exported_model_path)
exported_model_path = os.path.join(
    output_dir, "Model.ts"
)

scripted_model.save(exported_model_path)

torshscript_model = torch.jit.load(exported_model_path)

mlmodel = ct.convert(
    torshscript_model,
    source="pytorch",
    inputs=[
        ct.TensorType(name="input_signal", shape=audio_signal.shape),
        ct.TensorType(name="input_signal_length", shape=audio_signal_len.shape),
    ],
)
exported_model_path = os.path.join(output_dir, "Model.mlpackage")
mlmodel.save(exported_model_path)

System environment (please complete the following information):

TobyRoseman commented 1 year ago

Based on the error message, it seems you are trying to use an input with a different shape than what the model was traced with. This means you need to use Flexible Input Shapes.

hadiidbouk commented 1 year ago

@TobyRoseman So all the flexible input shapes solutions require have kind of limit to the shape input size, why do we need to have a limit? what kind of limitations we have here compared to the Pytorch Lighting export?

TobyRoseman commented 1 year ago

Yes, flexible input shapes require limits. This is a requirement of the Core ML Framework. I'm not familiar enough with PyTorch Lighting export to compare.

hadiidbouk commented 1 year ago

Seems that there is a bug in the convert when I use flexible input shapes πŸ€”:

When both 'convert_to' and 'minimum_deployment_target' not specified, 'convert_to' is set to "mlprogram" and 'minimum_deployment_targer' is set to ct.target.iOS15 (which is same as ct.target.macOS12). Note: the model will not run on systems older than iOS15/macOS12/watchOS8/tvOS15. In order to make your model run on older system, please set the 'minimum_deployment_target' to iOS14/iOS13. Details please see the link: https://coremltools.readme.io/docs/unified-conversion-api#target-conversion-formats
Support for converting Torch Script Models is experimental. If possible you should use a traced model for conversion.
Tuple detected at graph output. This will be flattened in the converted model.
Converting PyTorch Frontend ==> MIL Ops:  25%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–                                                                   | 121/486 [00:00<00:00, 1440.03 ops/s]
Traceback (most recent call last):
  File ".../pytorch-models/export_model.py", line 72, in <module>
    mlmodel = ct.convert(
  File ".../lib/python3.9/site-packages/coremltools/converters/_converters_entry.py", line 551, in convert
    mlmodel = mil_convert(
  File .../lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 188, in mil_convert
    return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
  File ".../lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 212, in _mil_convert
    proto, mil_program = mil_convert_to_proto(
  File ".../lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 286, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
  File ".../lib/python3.9/site-packages/coremltools/converters/mil/converter.py", line 108, in __call__
    return load(*args, **kwargs)
  File ".../lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 75, in load
    return _perform_torch_convert(converter, debug)
  File ".../lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 114, in _perform_torch_convert
    prog = converter.convert()
  File ".../lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 484, in convert
    convert_nodes(self.context, self.graph)
  File .../lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 93, in convert_nodes
    add_op(context, node)
  File ".../lib/python3.9/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 1628, in pad
    if pad.val is not None:
AttributeError: 'list' object has no attribute 'val'

Code:

rangeDim = ct.RangeDim(lower_bound=16000, upper_bound=16000 * 100, default=16000)
input_signal_shape = ct.Shape(shape=(1, rangeDim))
input_signal_len_shape = ct.Shape(shape=[rangeDim])

mlmodel = ct.convert(
    torshscript_model,
    source="pytorch",
    inputs=[
        ct.TensorType(name="input_signal", shape=input_signal_shape),
        ct.TensorType(name="input_signal_length", shape=input_signal_len_shape),
    ]
)
os.remove(exported_model_path)
exported_model_path = os.path.join(output_dir, "Model.mlpackage")
mlmodel.save(exported_model_path)
TobyRoseman commented 1 year ago

Try calling torch.jit.trace on your PyTorch model prior to conversion.

hadiidbouk commented 1 year ago

But that is what I am currently doing πŸ€”

TobyRoseman commented 1 year ago

But that is what I am currently doing πŸ€”

It doesn't seem so. Note this line in your output:

Support for converting Torch Script Models is experimental. If possible you should use a traced model for conversion.
hadiidbouk commented 1 year ago

@TobyRoseman Here is my full code (sorry the scripted_model is confusing) :

custom_model = MyCustomModel()
custom_model.eval()

audio_signal = torch.randn(1, 16000 * 100)
audio_signal_len = torch.tensor([audio_signal.shape[1]])

scripted_model = torch.jit.trace(
    custom_model.forward, example_inputs=(audio_signal, audio_signal_len)
)

os.remove(exported_model_path)
exported_model_path = os.path.join(
    output_dir, "MyModel.ts"
)

scripted_model.save(exported_model_path)

torshscript_model = torch.jit.load(exported_model_path)

mlmodel = ct.convert(
    scripted_model,
    source="pytorch",
    inputs=[
        ct.TensorType(
            name="inputSignal",
            shape=(
                1,
                ct.RangeDim(16000, 16000 * 100),
            ),
            dtype=np.float32,
        ),
        ct.TensorType(
            name="inputSignalLength",
            shape=(ct.RangeDim(16000, 16000 * 100),),
            dtype=np.int64,
        ),
    ]
)
os.remove(exported_model_path)
exported_model_path = os.path.join(output_dir, "MyModel.mlpackage")
mlmodel.save(exported_model_path)
TobyRoseman commented 1 year ago

Are you still getting the following warning?

Support for converting Torch Script Models is experimental. If possible you should use a traced model for conversion.

If so, then I don't think your model is actually traced. Here is the check for that warning.

Perhaps part of your model is tagged with the @torch.jit.script decorator which I believe prevents it from getting traced.

Also I'm not sure why the first parameter to torch.jit.trace is custom_model.forward rather than just custom_model. I'm not sure if that could be causing issues.

Since you didn't share the implementation of MyCustomModel, I can't reproduce that. If I can reproduce this issue, I'll take a deeper look.

hadiidbouk commented 1 year ago

No, I am not getting the warning anymore. I was getting the warning when I was tracing the model saving it in a file then loading it again.

I can guarantee that the model is traced since it's already working with LibTorch.

Here is the full implementation.

But I believe the problem could be related to #1921, it seems the same case to me

xorange commented 12 months ago

@hadiidbouk : Could you try https://github.com/apple/coremltools/pull/2050 to see if it fixes this problem ? Or could you provide a standalone minimum example for reproduce ?

I cannot reproduce it for this line:

custom_model = MyCustomModel()