iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

pack_padded_sequence/pad_packed_sequence from torch.nn.utils.rnn #15291

Open JBloodless opened 11 months ago

JBloodless commented 11 months ago

Request description

Hi. I tried to convert a model that uses pack_padded_sequence/pad_packed_sequence and found that they are apparently not supported yet. These functions are crucial for some high-performance models (aimed at fast training and inference), and I turned to IREE to further improve performance, so I think it would be useful to add support for them.

Minimal repro:

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import shark_turbine.aot as aot
from iree import runtime as ireert

class LSTM(nn.Module):

    def __init__(self,
                 input_size,
                 lstm_h=128,
                 num_layers=2,
                 dropout=0.1,
                 bidirectional=True
                 ):
        super().__init__()

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=lstm_h,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
            bidirectional=bidirectional
        )
        if bidirectional:
            num_directions = 2
        else:
            num_directions = 1
        self.fan_out = num_directions * lstm_h

    def forward(self, x, n_wins, h0=None, c0=None):
        x = pack_padded_sequence(
            x,
            n_wins.cpu(),
            batch_first=True,
            enforce_sorted=False
        )
        # self.lstm.flatten_parameters()
        x, (h, c) = self.lstm(x, (h0, c0))
        x, _ = pad_packed_sequence(
            x,
            batch_first=True,
            padding_value=0.0,
            total_length=n_wins.max())
        return x, h, c

x = torch.zeros((1, 63, 384))
n_wins = 63
h0 = torch.zeros(2, 1, 128)
c0 = torch.zeros(2, 1, 128)

model = LSTM(input_size=384)
model.eval()

example_input = (x, torch.as_tensor(n_wins).unsqueeze(0), h0, c0)
export_output = aot.export(model, *example_input)
binary = export_output.compile(save_to=None)

config = ireert.Config("local-task")
vm_module = ireert.load_vm_module(
    ireert.VmModule.wrap_buffer(config.vm_instance, binary.map_memory()),
    config,
)

out, h0, c0 = vm_module.main(*example_input)

Current error log:

Traceback (most recent call last):
  File "/Users/i.beskrovnyy/tts/NISQA-s/repro.py", line 58, in <module>
    export_output = aot.export(model, *example_input)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/shark_turbine/aot/exporter.py", line 198, in export
    cm = Exported(context=context, import_to="import")
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/shark_turbine/aot/compiled_module.py", line 534, in __new__
    do_export(proc_def)
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/shark_turbine/aot/compiled_module.py", line 531, in do_export
    trace.trace_py_func(invoke_with_self)
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/shark_turbine/aot/support/procedural/tracer.py", line 120, in trace_py_func
    return_py_value = _unproxy(py_f(*self.proxy_posargs, **self.proxy_kwargs))
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/shark_turbine/aot/compiled_module.py", line 512, in invoke_with_self
    return proc_def.callable(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/shark_turbine/aot/exporter.py", line 182, in main
    return jittable(mdl.forward)(*args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/shark_turbine/aot/support/procedural/base.py", line 137, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/shark_turbine/aot/support/procedural/tracer.py", line 136, in handle_call
    return target.resolve_call(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/shark_turbine/aot/builtins/jittable.py", line 207, in resolve_call
    transformed_f = functorch_functionalize(transformed_f, *flat_pytorch_args)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/shark_turbine/aot/passes/functorch.py", line 47, in functorch_functionalize
    new_gm = proxy_tensor.make_fx(
             ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 841, in wrapped
    t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 406, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 461, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 817, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 497, in wrapped
    out = f(*tensors)
          ^^^^^^^^^^^
  File "<string>", line 1, in <lambda>
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/shark_turbine/aot/passes/functorch.py", line 65, in wrapped
    out = function(*args_functional)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/shark_turbine/aot/builtins/jittable.py", line 202, in flat_wrapped_f
    return self.wrapped_f(*pytorch_args, **pytorch_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/tts/NISQA-s/repro.py", line 34, in forward
    x = pack_padded_sequence(
        ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/torch/nn/utils/rnn.py", line 264, in pack_padded_sequence
    _VF._pack_padded_sequence(input, lengths, batch_first)
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 574, in __torch_dispatch__
    return self.inner_torch_dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 609, in inner_torch_dispatch
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 354, in proxy_call
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/torch/_ops.py", line 498, in __call__
    return self._op(*args, **kwargs or {})
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1304, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1601, in dispatch
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/i.beskrovnyy/anaconda3/envs/mlir/lib/python3.11/site-packages/torch/_ops.py", line 498, in __call__
    return self._op(*args, **kwargs or {})
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Cannot call numel() on tensor with symbolic sizes/strides
Exception raised from throw_cannot_call_with_symbolic at /Users/runner/work/pytorch/pytorch/pytorch/c10/core/TensorImpl.cpp:594 (most recent call first):
frame #0: c10::TensorImpl::throw_cannot_call_with_symbolic(char const*) const + 124 (0x10fb0cad0 in libc10.dylib)
frame #1: c10::TensorImpl::numel_custom() const + 332 (0x10fb0d740 in libc10.dylib)
frame #2: at::native::_pack_padded_sequence(at::Tensor const&, at::Tensor const&, bool) + 3116 (0x2806deae8 in libtorch_cpu.dylib)
frame #3: c10::impl::make_boxed_from_unboxed_functor<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::__1::tuple<at::Tensor, at::Tensor> (at::Tensor const&, at::Tensor const&, bool), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeExplicitAutograd___pack_padded_sequence(at::Tensor const&, at::Tensor const&, bool)>, std::__1::tuple<at::Tensor, at::Tensor>, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, bool>>, false>::call(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) + 80 (0x28176fe10 in libtorch_cpu.dylib)
frame #4: c10::Dispatcher::callBoxedForDispatchKey(c10::OperatorHandle const&, c10::DispatchKey, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) const + 268 (0x163c8da3c in libtorch_python.dylib)
frame #5: (anonymous namespace)::ConcretePyInterpreterVTable::python_dispatcher(c10::OperatorHandle const&, c10::DispatchKeySet, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) const + 540 (0x163c82e40 in libtorch_python.dylib)
frame #6: c10::Dispatcher::callBoxed(c10::OperatorHandle const&, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) const + 300 (0x28002d95c in libtorch_cpu.dylib)
frame #7: torch::jit::invokeOperatorFromPython(std::__1::vector<std::__1::shared_ptr<torch::jit::Operator>, std::__1::allocator<std::__1::shared_ptr<torch::jit::Operator>>> const&, pybind11::args, pybind11::kwargs const&, c10::optional<c10::DispatchKey>) + 396 (0x163eb9468 in libtorch_python.dylib)
frame #8: torch::jit::_get_operation_for_overload_or_packet(std::__1::vector<std::__1::shared_ptr<torch::jit::Operator>, std::__1::allocator<std::__1::shared_ptr<torch::jit::Operator>>> const&, c10::Symbol, pybind11::args, pybind11::kwargs const&, bool, c10::optional<c10::DispatchKey>) + 992 (0x163eb9c1c in libtorch_python.dylib)
frame #9: std::__1::enable_if<!std::is_void<pybind11::object>::value, pybind11::object>::type pybind11::detail::argument_loader<pybind11::args, pybind11::kwargs>::call<pybind11::object, pybind11::detail::void_type, torch::jit::initJITBindings(_object*)::$_200::operator()(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&) const::'lambda'(pybind11::args, pybind11::kwargs)&>(torch::jit::initJITBindings(_object*)::$_200::operator()(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&) const::'lambda'(pybind11::args, pybind11::kwargs)&) && + 188 (0x163debc58 in libtorch_python.dylib)
frame #10: void pybind11::cpp_function::initialize<torch::jit::initJITBindings(_object*)::$_200::operator()(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&) const::'lambda'(pybind11::args, pybind11::kwargs), pybind11::object, pybind11::args, pybind11::kwargs>(torch::jit::initJITBindings(_object*)::$_200::operator()(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&) const::'lambda'(pybind11::args, pybind11::kwargs)&&, pybind11::object (*)(pybind11::args, pybind11::kwargs))::'lambda'(pybind11::detail::function_call&)::__invoke(pybind11::detail::function_call&) + 68 (0x163deba58 in libtorch_python.dylib)
frame #11: pybind11::cpp_function::dispatcher(_object*, _object*, _object*) + 3784 (0x16368e6e4 in libtorch_python.dylib)
frame #12: cfunction_call + 60 (0x1028bca68 in python3.11)
frame #13: _PyEval_EvalFrameDefault + 215396 (0x1029a42ec in python3.11)
frame #14: _PyFunction_Vectorcall + 476 (0x10285ad30 in python3.11)
frame #15: _PyObject_Call_Prepend + 164 (0x10285d4fc in python3.11)
frame #16: slot_tp_call + 120 (0x1028ea9ec in python3.11)
frame #17: _PyEval_EvalFrameDefault + 214140 (0x1029a3e04 in python3.11)
frame #18: _PyFunction_Vectorcall + 476 (0x10285ad30 in python3.11)
frame #19: _PyEval_EvalFrameDefault + 213732 (0x1029a3c6c in python3.11)
frame #20: _PyFunction_Vectorcall + 476 (0x10285ad30 in python3.11)
frame #21: method_vectorcall + 164 (0x10286094c in python3.11)
frame #22: _PyObject_CallFunctionVa + 148 (0x10285e100 in python3.11)
frame #23: PyObject_CallMethod + 112 (0x10285e518 in python3.11)
frame #24: torch::handle_torch_function_no_python_arg_parser(c10::ArrayRef<_object*>, _object*, _object*, char const*, _object*, char const*, torch::TorchFunctionName) + 1360 (0x16413f610 in libtorch_python.dylib)
frame #25: (anonymous namespace)::ConcretePyInterpreterVTable::dispatch(c10::OperatorHandle const&, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) const + 452 (0x163c82890 in libtorch_python.dylib)
frame #26: void c10::BoxedKernel::make_boxed_function<&(anonymous namespace)::pythonFallback(c10::OperatorHandle const&, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*)>(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) + 116 (0x28034a84c in libtorch_cpu.dylib)
frame #27: c10::Dispatcher::callBoxedForDispatchKey(c10::OperatorHandle const&, c10::DispatchKey, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) const + 268 (0x163c8da3c in libtorch_python.dylib)
frame #28: (anonymous namespace)::ConcretePyInterpreterVTable::python_dispatcher(c10::OperatorHandle const&, c10::DispatchKeySet, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) const + 540 (0x163c82e40 in libtorch_python.dylib)
frame #29: c10::Dispatcher::callBoxed(c10::OperatorHandle const&, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) const + 300 (0x28002d95c in libtorch_cpu.dylib)
frame #30: torch::jit::invokeOperatorFromPython(std::__1::vector<std::__1::shared_ptr<torch::jit::Operator>, std::__1::allocator<std::__1::shared_ptr<torch::jit::Operator>>> const&, pybind11::args, pybind11::kwargs const&, c10::optional<c10::DispatchKey>) + 396 (0x163eb9468 in libtorch_python.dylib)
frame #31: torch::jit::_get_operation_for_overload_or_packet(std::__1::vector<std::__1::shared_ptr<torch::jit::Operator>, std::__1::allocator<std::__1::shared_ptr<torch::jit::Operator>>> const&, c10::Symbol, pybind11::args, pybind11::kwargs const&, bool, c10::optional<c10::DispatchKey>) + 992 (0x163eb9c1c in libtorch_python.dylib)
frame #32: std::__1::enable_if<!std::is_void<pybind11::object>::value, pybind11::object>::type pybind11::detail::argument_loader<pybind11::args, pybind11::kwargs>::call<pybind11::object, pybind11::detail::void_type, torch::jit::initJITBindings(_object*)::$_200::operator()(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&) const::'lambda'(pybind11::args, pybind11::kwargs)&>(torch::jit::initJITBindings(_object*)::$_200::operator()(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&) const::'lambda'(pybind11::args, pybind11::kwargs)&) && + 188 (0x163debc58 in libtorch_python.dylib)
frame #33: void pybind11::cpp_function::initialize<torch::jit::initJITBindings(_object*)::$_200::operator()(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&) const::'lambda'(pybind11::args, pybind11::kwargs), pybind11::object, pybind11::args, pybind11::kwargs>(torch::jit::initJITBindings(_object*)::$_200::operator()(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&) const::'lambda'(pybind11::args, pybind11::kwargs)&&, pybind11::object (*)(pybind11::args, pybind11::kwargs))::'lambda'(pybind11::detail::function_call&)::__invoke(pybind11::detail::function_call&) + 68 (0x163deba58 in libtorch_python.dylib)
frame #34: pybind11::cpp_function::dispatcher(_object*, _object*, _object*) + 3784 (0x16368e6e4 in libtorch_python.dylib)
frame #35: cfunction_call + 60 (0x1028bca68 in python3.11)
frame #36: _PyEval_EvalFrameDefault + 215396 (0x1029a42ec in python3.11)
frame #37: _PyFunction_Vectorcall + 476 (0x10285ad30 in python3.11)
frame #38: _PyObject_Call_Prepend + 164 (0x10285d4fc in python3.11)
frame #39: slot_tp_call + 120 (0x1028ea9ec in python3.11)
frame #40: _PyEval_EvalFrameDefault + 214140 (0x1029a3e04 in python3.11)
frame #41: _PyFunction_Vectorcall + 476 (0x10285ad30 in python3.11)
frame #42: _PyEval_EvalFrameDefault + 213732 (0x1029a3c6c in python3.11)
frame #43: _PyFunction_Vectorcall + 476 (0x10285ad30 in python3.11)
frame #44: method_vectorcall + 164 (0x10286094c in python3.11)
frame #45: _PyObject_CallFunctionVa + 148 (0x10285e100 in python3.11)
frame #46: PyObject_CallMethod + 112 (0x10285e518 in python3.11)
frame #47: torch::handle_torch_function_no_python_arg_parser(c10::ArrayRef<_object*>, _object*, _object*, char const*, _object*, char const*, torch::TorchFunctionName) + 1360 (0x16413f610 in libtorch_python.dylib)
frame #48: (anonymous namespace)::ConcretePyInterpreterVTable::dispatch(c10::OperatorHandle const&, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) const + 452 (0x163c82890 in libtorch_python.dylib)
frame #49: void c10::BoxedKernel::make_boxed_function<&(anonymous namespace)::pythonFallback(c10::OperatorHandle const&, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*)>(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) + 116 (0x28034a84c in libtorch_cpu.dylib)
frame #50: c10::Dispatcher::callBoxedForDispatchKey(c10::OperatorHandle const&, c10::DispatchKey, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) const + 268 (0x163c8da3c in libtorch_python.dylib)
frame #51: (anonymous namespace)::ConcretePyInterpreterVTable::python_dispatcher(c10::OperatorHandle const&, c10::DispatchKeySet, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) const + 540 (0x163c82e40 in libtorch_python.dylib)
frame #52: void c10::BoxedKernel::make_boxed_function<&(anonymous namespace)::pythonTLSSnapshotFallback(c10::OperatorHandle const&, c10::DispatchKeySet, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*)>(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) + 192 (0x28034ac60 in libtorch_cpu.dylib)
frame #53: c10::Dispatcher::callBoxedForDispatchKey(c10::OperatorHandle const&, c10::DispatchKey, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) const + 268 (0x163c8da3c in libtorch_python.dylib)
frame #54: (anonymous namespace)::ConcretePyInterpreterVTable::python_dispatcher(c10::OperatorHandle const&, c10::DispatchKeySet, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) const + 540 (0x163c82e40 in libtorch_python.dylib)
frame #55: c10::Dispatcher::callBoxed(c10::OperatorHandle const&, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) const + 300 (0x28002d95c in libtorch_cpu.dylib)
frame #56: void c10::BoxedKernel::make_boxed_function<&(anonymous namespace)::functionalizeFallback(c10::OperatorHandle const&, c10::DispatchKeySet, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*)>(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) + 1408 (0x28002c624 in libtorch_cpu.dylib)
frame #57: c10::Dispatcher::callBoxedForDispatchKey(c10::OperatorHandle const&, c10::DispatchKey, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) const + 268 (0x163c8da3c in libtorch_python.dylib)
frame #58: (anonymous namespace)::ConcretePyInterpreterVTable::python_dispatcher(c10::OperatorHandle const&, c10::DispatchKeySet, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) const + 540 (0x163c82e40 in libtorch_python.dylib)
frame #59: c10::impl::BoxedKernelWrapper<std::__1::tuple<at::Tensor, at::Tensor> (at::Tensor const&, at::Tensor const&, bool), void>::call(c10::BoxedKernel const&, c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, bool) + 164 (0x2813014ec in libtorch_cpu.dylib)
frame #60: c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::__1::tuple<at::Tensor, at::Tensor> (c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, bool), &torch::autograd::VariableType::(anonymous namespace)::_pack_padded_sequence(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, bool)>, std::__1::tuple<at::Tensor, at::Tensor>, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, bool>>, std::__1::tuple<at::Tensor, at::Tensor> (c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, bool)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, bool) + 1528 (0x28329f320 in libtorch_cpu.dylib)
frame #61: c10::impl::make_boxed_from_unboxed_functor<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::__1::tuple<at::Tensor, at::Tensor> (c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, bool), &torch::autograd::VariableType::(anonymous namespace)::_pack_padded_sequence(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, bool)>, std::__1::tuple<at::Tensor, at::Tensor>, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, bool>>, false>::call(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) + 84 (0x28329fe40 in libtorch_cpu.dylib)
frame #62: c10::Dispatcher::callBoxedForDispatchKey(c10::OperatorHandle const&, c10::DispatchKey, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) const + 268 (0x163c8da3c in libtorch_python.dylib)
frame #63: (anonymous namespace)::ConcretePyInterpreterVTable::python_dispatcher(c10::OperatorHandle const&, c10::DispatchKeySet, std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>*) const + 540 (0x163c82e40 in libtorch_python.dylib)
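
For context on the failing call: pack_padded_sequence produces a flat, time-major data buffer plus a batch_sizes list, and the length of that buffer is the sum of the entries in lengths, so it depends on tensor values rather than on shapes alone. The pure-Python sketch below mimics those semantics on nested lists to make that visible. It is only an illustration: the helper names pack_padded and pad_packed are hypothetical, this is not PyTorch's implementation, and whether this data dependence is what triggers the symbolic-shape error above is an assumption.

```python
def pack_padded(padded, lengths):
    """Pack a batch-first list of equally padded sequences (hypothetical helper).

    Returns (data, batch_sizes, order): the flat time-major data, the number of
    sequences still active at each time step, and the batch indices sorted by
    length, longest first.
    """
    # Sort batch indices by true length, longest first.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    max_len = lengths[order[0]]
    data, batch_sizes = [], []
    for t in range(max_len):
        # Only sequences that are longer than t contribute at step t.
        step = [padded[i][t] for i in order if lengths[i] > t]
        data.extend(step)
        batch_sizes.append(len(step))
    return data, batch_sizes, order


def pad_packed(data, batch_sizes, order, total_length, padding_value=0.0):
    """Invert pack_padded, restoring the original batch order (hypothetical helper)."""
    out = [[padding_value] * total_length for _ in order]
    pos = 0
    for t, n in enumerate(batch_sizes):
        # The first n entries of `order` are the sequences active at step t.
        for rank in range(n):
            out[order[rank]][t] = data[pos]
            pos += 1
    return out


padded = [[1, 2, 0], [3, 4, 5], [6, 0, 0]]
lengths = [2, 3, 1]
data, batch_sizes, order = pack_padded(padded, lengths)
restored = pad_packed(data, batch_sizes, order, total_length=3, padding_value=0)
print(data, batch_sizes, restored)
```

Note that len(data) equals sum(lengths), which cannot be known from the input shapes alone.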

What component(s) does this issue relate to?

Frontends, MLIR, Python

Additional context

No response

dan-garvey commented 10 months ago

@JBloodless check the PyTorch issue; I think they might have fixed it. You'll need to use a version of PyTorch from after the patch in that issue.

stellaraccident commented 10 months ago

(we also need to get turbine set up with a nightly PyTorch channel)

JBloodless commented 10 months ago

> @JBloodless check the PyTorch issue; I think they might have fixed it. You'll need to use a version of PyTorch from after the patch in that issue.

Still getting the same error on torch 2.2.0.dev20231113. Also, the code runs (with the same error), but pip reports:

torch-mlir 20230924.971 requires torch==2.2.0.dev20230922, but you have torch 2.2.0.dev20231113 which is incompatible.
shark-turbine 0.9.1.dev3 requires iree-runtime>=20231026.688, but you have iree-runtime 20231004.665 which is incompatible.

which is what @stellaraccident mentioned in the previous comment.

stellaraccident commented 10 months ago

This is a PyTorch error. Is there a linked PyTorch issue?

If you're working with a nightly, you'll need to use pip with --no-deps and manage dependencies manually. We don't currently test the nightlies, and the public API is set to change, so we're keeping the pin to stable until we do.
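
When managing dependencies manually, it can help to print what is actually installed before comparing it against the pins in the pip output. A minimal sketch using only the standard library follows; the package list is taken from the pip messages in this thread, and the helper name installed_versions is hypothetical.

```python
from importlib import metadata

# Distribution names as they appear in the pip messages in this thread.
PACKAGES = ("torch", "torch-mlir", "shark-turbine", "iree-runtime")


def installed_versions(names=PACKAGES):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

This only reports versions; it does not check them against each package's declared requirements.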

JBloodless commented 10 months ago

> This is a PyTorch error. Is there a linked PyTorch issue?
>
> If you're working with a nightly, you'll need to use pip with --no-deps and manage dependencies manually. We don't currently test the nightlies, and the public API is set to change, so we're keeping the pin to stable until we do.

Yeah, there is a PyTorch issue (a couple of comments up), and people here said that the initial problem with pack/pad could be resolved in the nightly, but it is not. torch-mlir and turbine seem to work with nightlies; I only mentioned the pip logs to check whether that's okay (it seems to be). Thanks for the response.

JBloodless commented 8 months ago

I pinged the PyTorch issue, but has there been any further development here?

JBloodless commented 5 months ago

@stellaraccident @dan-garvey is there any progress here? :)

stellaraccident commented 5 months ago

I don't have any secret channel for getting these things resolved upstream. Also, I don't have any PyTorch RNNs on my immediate priority list, so I'm unlikely to have the time to find a creative solution.