apple / coremltools

Core ML tools contain supporting tools for Core ML model conversion, editing, and validation.
https://coremltools.readme.io
BSD 3-Clause "New" or "Revised" License

ct.convert call appears to corrupt torchscript model #2215

Open carsonswope opened 4 months ago

carsonswope commented 4 months ago

🐞Describing the bug

After running ct.convert on a TorchScript model, the TorchScript model appears to be corrupted and no longer saves and reloads correctly. The stack trace comes from torch, but the failure only occurs after the model has been processed by ct.convert.

Stack Trace

python version: 3.9.19 (main, May  6 2024, 14:39:30)
[Clang 14.0.6 ]
torch version: 2.2.0
ct version: 7.2
** model loaded correctly before ct.convert
Converting PyTorch Frontend ==> MIL Ops:  75%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š                 | 3/4 [00:00<00:00, 1289.10 ops/s]
Running MIL frontend_pytorch pipeline: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 5/5 [00:00<00:00, 6848.96 passes/s]
Running MIL default pipeline:   0%|                                                                                      | 0/78 [00:00<?, ? passes/s]/Users/carson/miniconda3/envs/ct_convert_error/lib/python3.9/site-packages/coremltools/converters/mil/mil/passes/defs/preprocess.py:266: UserWarning: Output, '7', of the source model, has been renamed to 'var_7' in the Core ML model.
  warnings.warn(msg.format(var.name, new_name))
Running MIL default pipeline: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 78/78 [00:00<00:00, 5293.10 passes/s]
Running MIL backend_mlprogram pipeline: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 12/12 [00:00<00:00, 12738.96 passes/s]
Traceback (most recent call last):
  File "/Users/carson/code/bfx/ai/repro.py", line 42, in <module>
    _ = torch.jit.load(f1)
  File "/Users/carson/miniconda3/envs/ct_convert_error/lib/python3.9/site-packages/torch/jit/_serialization.py", line 159, in load
    cpp_module = torch._C.import_ir_module(cu, str(f), map_location, _extra_files, _restore_shapes)  # type: ignore[call-arg]
RuntimeError: required keyword attribute 'chunks' is undefined

To Reproduce

Python script:

import torch
import torch.nn as nn
import coremltools as ct
import numpy as np
import sys

f0 = 'tmp0.pt'
f1 = 'tmp1.pt'

print(f'python version: {sys.version}')
print(f'torch version: {torch.__version__}')
print(f'ct version: {ct.__version__}')

class Net(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        a,b,c = x.chunk(3)
        return (a * b) + c

with torch.no_grad():

    i = torch.rand((768, 256))
    net = Net().eval()
    net_traced = torch.jit.trace(net, (i))

    # this works..
    net_traced.save(f0)
    _ = torch.jit.load(f0)
    print('** model loaded correctly before ct.convert')

    ct.convert(
        net_traced,
        convert_to='mlprogram',
        minimum_deployment_target=ct.target.macOS12,
        compute_units=ct.ComputeUnit.ALL,
        inputs=[ct.TensorType(name='i0', shape=(768, 256), dtype=np.float32)])

    # this doesn't..
    net_traced.save(f1)
    _ = torch.jit.load(f1)
    print('** model loaded correctly after ct.convert')


TobyRoseman commented 4 months ago

I can reproduce this issue. The error actually occurs when the second saved PyTorch model is loaded. I suspect it is the result of some of the PyTorch graph lowering we do during conversion.
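Until this is fixed, one possible workaround (a sketch, not an official recommendation; the helper name is hypothetical) is to hand ct.convert an independent copy of the traced module, so any in-place mutation during conversion never touches the original. Round-tripping through an in-memory buffer with `torch.jit.save`/`torch.jit.load` produces such a copy:

```python
import io

import torch


def clone_scripted(module: torch.jit.ScriptModule) -> torch.jit.ScriptModule:
    # Serialize the traced/scripted module to an in-memory buffer and load
    # it back; the result is an independent ScriptModule, so mutating it
    # (e.g. inside ct.convert) leaves the original unaffected.
    buffer = io.BytesIO()
    torch.jit.save(module, buffer)
    buffer.seek(0)
    return torch.jit.load(buffer)
```

With this helper, calling `ct.convert(clone_scripted(net_traced), ...)` in the repro script above should leave `net_traced` untouched, so the later `net_traced.save(f1)` / `torch.jit.load(f1)` round-trip still works.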