apple / coremltools

Core ML tools contain supporting tools for Core ML model conversion, editing, and validation.
https://coremltools.readme.io
BSD 3-Clause "New" or "Revised" License

Add support for torch op conv1d #1753

Open fursund opened 1 year ago

fursund commented 1 year ago

coremltools is missing conv1d support for torch

TobyRoseman commented 1 year ago

Looking at the PyTorch ops we support, we are missing support for conv1d.

However the following works:

import coremltools as ct
import torch

m_torch = torch.nn.Conv1d(16, 33, 3)
x = torch.randn(20, 16, 50)
m_torch = torch.jit.trace(m_torch, x)

m_cm = ct.convert(m_torch, inputs=[ct.TensorType(shape=x.shape, name="x")])
m_cm.predict({'x': x})

The conv1d torch op must be getting lowered to an op we do support.
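One way to check which ops a traced model actually contains is to list the node kinds in its TorchScript graph before calling ct.convert. The following is a minimal sketch that assumes only PyTorch is installed; op_kinds is a hypothetical helper written for illustration and is not part of coremltools or PyTorch.

```python
import torch


def op_kinds(graph_or_block):
    """Collect the kind string of every node, including nodes in nested blocks.

    Hypothetical helper for illustration; not a coremltools or PyTorch API.
    """
    kinds = set()
    for node in graph_or_block.nodes():
        kinds.add(node.kind())
        # Control-flow nodes (e.g. prim::If) hold their bodies in sub-blocks.
        for block in node.blocks():
            kinds.update(op_kinds(block))
    return kinds


# Same module and input shape as the snippet above.
module = torch.nn.Conv1d(16, 33, 3)
x = torch.randn(20, 16, 50)
traced = torch.jit.trace(module, x)

# inlined_graph expands calls into submodules so their ops are visible too.
traced_kinds = sorted(op_kinds(traced.inlined_graph))
print(traced_kinds)
```

Running the same helper on the traced graph of a model that fails to convert should show whether a conv1d node is present in it.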

Can someone provide a toy example where our lack of conv1d support causes model conversion to fail?

fursund commented 1 year ago

With this https://gist.github.com/fursund/39c897d25f583686fe2626c56b48ffa3 and coremltools 6.2 it will hit the conv1d op

fursund commented 1 year ago

Best bet is that it has something to do with: https://asteroid.readthedocs.io/en/v0.3.1/_modules/asteroid/filterbanks/enc_dec.html#Encoder

TobyRoseman commented 1 year ago

> With this https://gist.github.com/fursund/39c897d25f583686fe2626c56b48ffa3 and coremltools 6.2 it will hit the conv1d op

Did you mean to share a different link? The error here is not related to conv1d.

> Best bet is that it has something to do with: https://asteroid.readthedocs.io/en/v0.3.1/_modules/asteroid/filterbanks/enc_dec.html#Encoder

There is quite a bit of code on this page. None of it is using coremltools. Can you provide a minimal example where conversion fails because we do not support conv1d?

fursund commented 1 year ago

Ok. Tried to reduce the issue a bit:

import coremltools as ct
import torch
from asteroid_filterbanks import Encoder, ParamSincFB

m_torch = Encoder(
    ParamSincFB(
        80,
        251,
        stride=1,
        sample_rate=16000,
        min_low_hz=50,
        min_band_hz=50,
    )
)
print(m_torch)
x = torch.randn(10, 1, 1024)
m_torch = torch.jit.trace(m_torch, x)

m_cm = ct.convert(m_torch, inputs=[ct.TensorType(shape=x.shape, name="x")])

TobyRoseman commented 1 year ago

Thanks @fursund. After running pip install asteroid-filterbanks, I can reproduce the problem using your code.

Looks like there is still a lot going on inside Encoder and ParamSincFB. Ideally, we'd have a toy example (i.e. something we could use as a unit test).

fursund commented 1 year ago

Yeah. Not involved in that project, but when I get a moment I'll try and make the test even more barebones.

yych42 commented 1 year ago

Did anyone make progress on this? I was trying to convert TitaNet, but the conversion failed with this issue.