apple / coremltools

Core ML tools contain supporting tools for Core ML model conversion, editing, and validation.
https://coremltools.readme.io
BSD 3-Clause "New" or "Revised" License

Support LSTMCell layer from PyTorch #1344

Open dzhelonkin opened 2 years ago

dzhelonkin commented 2 years ago

🌱 Describe your Feature Request

Add PyTorch conversion support for LSTMCell, not only for LSTM. At a minimum, this requires support for the unsafe_chunk op.
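For context on why a chunk-style op shows up at all, here is a minimal pure-Python sketch of a single LSTM cell step. It assumes PyTorch's documented gate layout (input, forget, cell, output, stacked along one dimension of size 4 * hidden_size). The helper names (`chunk`, `matvec`, `lstm_cell_step`) are illustrative and are not part of PyTorch or coremltools. The 4-way split of the stacked gate vector is the step that presumably appears as unsafe_chunk in the traced graph.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def chunk(values, n):
    """Split a flat list into n equal, consecutive pieces."""
    size = len(values) // n
    return [values[k * size:(k + 1) * size] for k in range(n)]


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]


def lstm_cell_step(x, h, c, w_ih, w_hh, b_ih, b_hh):
    """One LSTM cell step for a single, unbatched input vector (illustrative)."""
    # All four gates are computed together as one vector of length 4 * hidden_size.
    gates = [
        a + b + p + q
        for a, b, p, q in zip(matvec(w_ih, x), b_ih, matvec(w_hh, h), b_hh)
    ]
    # The 4-way split. Assumed gate order: input, forget, cell, output.
    i, f, g, o = chunk(gates, 4)
    i = [sigmoid(v) for v in i]
    f = [sigmoid(v) for v in f]
    g = [math.tanh(v) for v in g]
    o = [sigmoid(v) for v in o]
    # New cell state, then new hidden state.
    c_next = [fv * cv + iv * gv for fv, cv, iv, gv in zip(f, c, i, g)]
    h_next = [ov * math.tanh(cv) for ov, cv in zip(o, c_next)]
    return h_next, c_next
```

Everything besides the split is matrix multiplication, addition, sigmoid, and tanh, which is why the request singles out unsafe_chunk as the minimum missing piece.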

Use cases

LSTMCell is needed for streaming (online) processing on device, where input arrives one time step at a time and the hidden and cell states are carried between calls. This is useful for speech processing and other time-series tasks.
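To make the streaming use case concrete, here is a rough PyTorch-only sketch (no Core ML involved). It copies the weights of an nn.LSTM into an nn.LSTMCell, then feeds the sequence one frame at a time while carrying the hidden and cell states between calls. The `stream` helper and the variable names are illustrative assumptions, not an existing API.

```python
import torch
from torch import nn

torch.manual_seed(0)

input_size, hidden_size = 3, 3
lstm = nn.LSTM(input_size, hidden_size)      # consumes a whole sequence at once
cell = nn.LSTMCell(input_size, hidden_size)  # consumes a single time step

# Share weights so that both modules compute the same function.
with torch.no_grad():
    cell.weight_ih.copy_(lstm.weight_ih_l0)
    cell.weight_hh.copy_(lstm.weight_hh_l0)
    cell.bias_ih.copy_(lstm.bias_ih_l0)
    cell.bias_hh.copy_(lstm.bias_hh_l0)

sequence = torch.randn(12, 1, input_size)  # (time, batch, features)


def stream(cell, frames):
    """Feed frames one by one, carrying (h, c) between calls (illustrative helper)."""
    h = torch.zeros(1, hidden_size)
    c = torch.zeros(1, hidden_size)
    outputs = []
    for frame in frames:  # each frame has shape (batch, features)
        h, c = cell(frame, (h, c))
        outputs.append(h)
    return torch.stack(outputs), h, c


with torch.no_grad():
    offline_out, (offline_h, offline_c) = lstm(sequence)
    online_out, online_h, online_c = stream(cell, sequence)

# The frame-by-frame result should match the full-sequence result numerically.
outputs_match = torch.allclose(offline_out, online_out, atol=1e-5)
```

On device, each call to a converted one-step model would play the role of one iteration of this loop, with the app holding the two state tensors between calls.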

Describe alternatives you've considered

Converting PyTorch -> ONNX -> Core ML works, but the ONNX -> Core ML converter will be deprecated.

Additional context

Working LSTM conversion:

import torch
from torch import nn
import coremltools as ct

lstm = nn.LSTM(3, 3)
lstm.eval()
inputs = torch.randn(12, 1, 3)
traced_model = torch.jit.trace(lstm, inputs)

ct.convert(
    model=traced_model,
    inputs=[
        ct.TensorType(name="sequence", shape=(ct.RangeDim(1, 50), 1, 3))
    ]
)

Failing LSTMCell conversion, which raises the error: PyTorch convert function for op 'unsafe_chunk' not implemented.

import torch
from torch import nn
import coremltools as ct

class OneStep(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstmcell = nn.LSTMCell(input_size=3, hidden_size=3)

    def forward(self, sequence, hidden_state, cell_state):
        return self.lstmcell(sequence, (hidden_state, cell_state))

lstm_cell = OneStep()
lstm_cell.eval()
# Example inputs for tracing: one input frame plus the hidden and cell states.
dummy_input = (torch.zeros((1, 3)), torch.zeros((1, 3)), torch.zeros((1, 3)))
traced_model = torch.jit.trace(lstm_cell, dummy_input)

ct.convert(
    model=traced_model,
    inputs=[
        ct.TensorType(name="sequence", shape=(1, 3)),
        ct.TensorType(name="hidden_state", shape=(1, 3)),
        ct.TensorType(name="cell_state", shape=(1, 3)),
    ]
)

TobyRoseman commented 2 years ago

@dzhelonkin - Your first code segment works for me using the most recent version of coremltools. Please try again with the latest version of coremltools. If you still get the error, let us know more information about your environment, such as your version of PyTorch.

dzhelonkin commented 2 years ago

@TobyRoseman The first code segment works as expected, but it covers LSTM, not LSTMCell. The feature request is to add support for converting LSTMCell from PyTorch, which is the second segment and does not work. In PyTorch, LSTM is not implemented as LSTMCell inside a for/while loop; the two have separate implementations. As a result, LSTM support does not cover LSTMCell, and additional functionality is required.

Environment: torch 1.10.0, coremltools 5.1.0