pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT
BSD 3-Clause "New" or "Revised" License

🐛 [Bug] Encountered bug when using Torch-TensorRT (torch.nn.LSTM) #2598

Open johnzlli opened 10 months ago

johnzlli commented 10 months ago

Bug Description

The following error is raised when using Torch-TensorRT to convert `torch.nn.LSTM` in the Docker image `nvcr.io/nvidia/pytorch:23.12-py3`:

NotImplementedError: aten::_cudnn_rnn_flatten_weight: attempted to run this operator with Meta tensors, but there was no abstract impl or Meta kernel registered. You may have run into this message while using an operator with PT2 compilation APIs (torch.compile/torch.export); in order to use this operator with those APIs you'll need to add an abstract impl

To Reproduce

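Example code:

```python
import torch
import torch_tensorrt
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(1024, 1024, batch_first=True)

    def forward(self, x):
        # nn.LSTM returns (output, (h_n, c_n)); keep only the output sequence
        x = self.lstm(x)[0]
        return x


model = Model().half().eval().cuda()
inputs = [torch.randn(100, 200, 1024).half().cuda()]

# The NotImplementedError above is raised while the dynamo IR exports and
# compiles the model (aten::_cudnn_rnn_flatten_weight has no Meta kernel)
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
trt_gm(*inputs)
```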

Expected behavior

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

Additional context

johnzlli commented 7 months ago

@narendasan Hi, is there any update?

gs-olive commented 7 months ago

Hello, as an update on this issue: one workaround to try is to compile with `ir="torch_compile"` and set `torch._dynamo.config.allow_rnn = True` at the top of the script.

For the `ir="dynamo"` path, there is a workaround described here: https://github.com/pytorch/pytorch/issues/121761#issuecomment-2021696208. The resulting `gm` object can then be used with Torch-TensorRT by passing it to `torch_tensorrt.compile`.

A more robust fix is pending resolution of these related issues: pytorch/pytorch/issues/120626 and pytorch/pytorch/issues/121761.