pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT
BSD 3-Clause "New" or "Revised" License

🐛 [Bug] Runtime error when loading compiled model #1550

Closed Nik-V9 closed 1 year ago

Nik-V9 commented 1 year ago

Bug Description

A runtime error is thrown when the compiled TorchScript model is loaded.

RuntimeError: [Error thrown at core/runtime/TRTEngine.cpp:132] Expected (binding_name == engine_binded_name) to be true but got false
Could not find a TensorRT engine binding for output named output_0

The same code works in a Docker image with Torch 1.10.0a0+0aef44c, Torch-TensorRT 1.0.0+55c3bab4, TensorRT 8.0.3-1, and CUDA 11.3.
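
For comparison, a minimal sketch that prints the matching versions in the failing environment (assuming the tensorrt Python bindings are installed alongside Torch-TensorRT):

import torch
import torch_tensorrt
import tensorrt

# Versions in the failing environment, to compare against the working image.
print("Torch:", torch.__version__)
print("Torch-TensorRT:", torch_tensorrt.__version__)
print("TensorRT:", tensorrt.__version__)
print("CUDA (Torch build):", torch.version.cuda)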

To Reproduce

Code:

import torch
import torch_tensorrt
import torch.nn as nn
import timm

class TrEstimator(nn.Module):
    def __init__(self, cfg, pretrained=True):
        super().__init__()
        base_model_name = cfg['base_model_name']
        input_depth = 2
        self.weight_scale = cfg['weight_scale']

        self.base_model = timm.create_model(base_model_name,
                                            features_only=True,
                                            in_chans=input_depth,
                                            pretrained=pretrained)

        out_ch = self.base_model.feature_info.channels()[-1]
        self.conv_heatmap = nn.Conv2d(out_ch, 1, kernel_size=1, bias=True)
        self.conv_offset = nn.Conv2d(out_ch, 2, kernel_size=1, bias=True)

    def freeze_encoder(self):
        self.base_model.freeze_encoder()

    def unfreeze_encoder(self):
        self.base_model.unfreeze_encoder()

    def forward(self, prev_frame, cur_frame):
        # Stack the two (N, H, W) frames into a 2-channel (N, 2, H, W) input.
        inputs = torch.stack([prev_frame, cur_frame], dim=1)
        x = self.base_model(inputs)
        x = x[-1]  # keep only the deepest feature map

        # Crop a 2-pixel border from the feature map.
        x = x[:, :, 2:-2, 2:-2]

        x_hm = self.conv_heatmap(x)
        # Normalize to a spatial probability map, sharpened by weight_scale.
        m = torch.exp(self.weight_scale * torch.sigmoid(x_hm))
        heatmap = m / torch.sum(m, dim=(2, 3), keepdim=True)
        offsets = self.conv_offset(x)

        return heatmap, offsets

model = TrEstimator(
        cfg=dict(
            base_model_name="resnet34",
            weight_scale=3.0
        ), pretrained=False)
model = model.half().cuda()
model.eval()

print('Compiling TensorRT model ...')
# Trace with two half-precision (1, 640, 1024) frames; forward() stacks them
# into a single 2-channel input for the backbone.
model_jit = torch.jit.trace(model, (torch.rand((1, 640, 1024)).half().cuda(), torch.rand((1, 640, 1024)).half().cuda()))
model = torch_tensorrt.compile(
    model_jit,
    inputs=[torch_tensorrt.Input(shape=[1, 640, 1024], dtype=torch.half), torch_tensorrt.Input(shape=[1, 640, 1024], dtype=torch.half)],
    enabled_precisions={torch.half},
    truncate_long_and_double=True
)
torch.jit.save(model, f"./model.ts")

print('Loading TensorRT model ...')
model_new = torch.jit.load(f"./model.ts")

Output:

Compiling TensorRT model ...
WARNING: [Torch-TensorRT] - For input prev_frame, found user specified input dtype as Float16. The compiler is going to use the user setting Float16
WARNING: [Torch-TensorRT] - For input cur_frame, found user specified input dtype as Float16. The compiler is going to use the user setting Float16
WARNING: [Torch-TensorRT] - Dilation not used in Max pooling converter
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - Sum converter disregards dtype
WARNING: [Torch-TensorRT TorchScript Conversion Context] - TensorRT was linked against cuDNN 8.6.0 but loaded cuDNN 8.5.0
Loading TensorRT model ...
Traceback (most recent call last):
  File "compile.py", line 61, in <module>
    model_new = torch.jit.load(f"./model.ts")
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_serialization.py", line 162, in load
    cpp_module = torch._C.import_ir_module(cu, str(f), map_location, _extra_files)
RuntimeError: [Error thrown at core/runtime/TRTEngine.cpp:132] Expected (binding_name == engine_binded_name) to be true but got false
Could not find a TensorRT engine binding for output named output_0

Expected behavior

Model loads successfully.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages.
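
If available in this build, the build info can also be dumped directly without enabling debug logging (a sketch; torch_tensorrt.dump_build_info() ships with recent releases, but the name may differ in older versions):

import torch_tensorrt

# Print the Torch-TensorRT build configuration (compiler, linked library versions).
torch_tensorrt.dump_build_info()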

Additional context

dpkg -l | grep nvinfer
ii  libnvinfer-bin                  8.5.1-1+cuda11.8                  amd64        TensorRT binaries
ii  libnvinfer-dev                  8.5.1-1+cuda11.8                  amd64        TensorRT development libraries and headers
ii  libnvinfer-plugin-dev           8.5.1-1+cuda11.8                  amd64        TensorRT plugin libraries
ii  libnvinfer-plugin8              8.5.1-1+cuda11.8                  amd64        TensorRT plugin libraries
ii  libnvinfer8                     8.5.1-1+cuda11.8                  amd64        TensorRT runtime libraries
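
The compile log above also warns that TensorRT was linked against cuDNN 8.6.0 but loaded cuDNN 8.5.0; a minimal sketch to confirm which cuDNN Torch actually loads:

import torch

# cuDNN version as loaded by Torch, e.g. 8500 for 8.5.0.
print(torch.backends.cudnn.version())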
peri044 commented 1 year ago

Can you enable debug logs by doing

with torch_tensorrt.logging.debug():
  model = torch_tensorrt.compile(
      model_jit, 
      inputs=[torch_tensorrt.Input(shape=[1, 640, 1024], dtype=torch.half), torch_tensorrt.Input(shape=[1, 640, 1024], dtype=torch.half)], 
      enabled_precisions={torch.half},
      truncate_long_and_double=True
  )

and attach the full log here?
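
Equivalently (a sketch assuming the 1.x torch_tensorrt.logging API), the log level can be raised globally instead of using the context manager:

import torch_tensorrt

# Emit debug-level logs for all subsequent Torch-TensorRT calls.
torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Debug)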

Nik-V9 commented 1 year ago

Here is the full log: log.txt

Nik-V9 commented 1 year ago

Thanks!