NVIDIA-AI-IOT / torch2trt

An easy to use PyTorch to TensorRT converter
MIT License

Can't convert Vision Transformer from timm - TypeError: add_constant(): incompatible function arguments. #882

Open jakubhejhal opened 10 months ago

jakubhejhal commented 10 months ago

The following code:

import torch
from torch2trt import torch2trt
import timm

vit_model = timm.create_model(
    model_name="vit_base_patch16_224_dino",
    pretrained=True,
)

vit_model.to("cuda")
vit_model.eval()

input_shape = (4, 3, 224, 224)
input_data = torch.rand(input_shape, dtype=torch.float32).cuda()
trt_model = torch2trt(vit_model, [input_data])

Gives the following error:

Warning: Encountered known unsupported method torch.Tensor.unbind
Traceback (most recent call last):
  File "/torch2trt_tests/tensorrt_experiments/torch2trt_timm_vis_transformer_pure_timm.py", line 15, in <module>
    trt_model = torch2trt(vit_model, [input_data])
  File "/venv/lib/python3.10/site-packages/torch2trt-0.4.0-py3.10.egg/torch2trt/torch2trt.py", line 779, in torch2trt
    outputs = module(*inputs)
  File "/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/venv/lib/python3.10/site-packages/timm/models/vision_transformer.py", line 549, in forward
    x = self.forward_features(x)
  File "/venv/lib/python3.10/site-packages/timm/models/vision_transformer.py", line 538, in forward_features
    x = self.blocks(x)
  File "/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/venv/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/venv/lib/python3.10/site-packages/timm/models/vision_transformer.py", line 268, in forward
    x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
  File "/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/venv/lib/python3.10/site-packages/timm/models/vision_transformer.py", line 220, in forward
    attn = (q @ k.transpose(-2, -1)) * self.scale
  File "/venv/lib/python3.10/site-packages/torch2trt-0.4.0-py3.10.egg/torch2trt/torch2trt.py", line 310, in wrapper
    converter["converter"](ctx)
  File "/venv/lib/python3.10/site-packages/torch2trt-0.4.0-py3.10.egg/torch2trt/converters/transpose.py", line 26, in convert_transpose_trt7
    input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
  File "/venv/lib/python3.10/site-packages/torch2trt-0.4.0-py3.10.egg/torch2trt/torch2trt.py", line 176, in add_missing_trt_tensors
    t._trt = network.add_constant(shape, weight).get_output(0)
  File "/venv/lib/python3.10/site-packages/torch2trt-0.4.0-py3.10.egg/torch2trt/torch2trt.py", line 400, in wrapper
    ret = attr(*args, **kwargs)
TypeError: add_constant(): incompatible function arguments. The following argument types are supported:
    1. (self: tensorrt_bindings.tensorrt.INetworkDefinition, shape: tensorrt_bindings.tensorrt.Dims, weights: tensorrt_bindings.tensorrt.Weights) -> tensorrt_bindings.tensorrt.IConstantLayer

Invoked with: <tensorrt_bindings.tensorrt.INetworkDefinition object at 0x7f01f07f96f0>, (4, 12, 197, 64), array([[[[-1.72472194e-01, -1.28571704e-01, -1.14402324e-01, ...,
          -1.26545317e-02,  5.78989200e-02, -6.65028393e-02],
         [-4.51039612e-01,  4.68602143e-02, -3.21930200e-01, ...,
         ....
         ...,
          -1.62460953e-01, -3.48698556e-01,  1.55591115e-01]]]],
      dtype=float32)
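The shape in the failing call, (4, 12, 197, 64), can be reproduced by hand. Below is a minimal sketch. It assumes the standard ViT-B/16 settings for vit_base_patch16_224_dino (embedding dim 768, 12 heads) plus the input size and batch from the script above. The result suggests that the array being frozen as a constant is the attention key tensor k, which is an activation and not a model weight. That would be consistent with the preceding warning that torch.Tensor.unbind is unsupported, although the thread does not confirm it. The helper name is hypothetical.

```python
def attention_qkv_shape(batch, image_size, patch_size, embed_dim, num_heads, cls_tokens=1):
    """Shape of each of q, k, v after timm-style (B, heads, N, head_dim) splitting.

    Hypothetical helper; assumes a square input and a plain ViT with
    `cls_tokens` extra tokens prepended to the patch tokens.
    """
    grid = image_size // patch_size          # patches per side
    num_tokens = grid * grid + cls_tokens    # patch tokens + class token
    head_dim = embed_dim // num_heads        # channels per attention head
    return (batch, num_heads, num_tokens, head_dim)


# Assumed ViT-B/16 configuration at 224x224, batch of 4 as in the report.
shape = attention_qkv_shape(batch=4, image_size=224, patch_size=16, embed_dim=768, num_heads=12)
print(shape)  # (4, 12, 197, 64)
```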

I'm using:

torch==2.0.1
torch2trt==0.4.0
timm==0.6.13
shanek16 commented 9 months ago

Same error here, using torch2trt at commit 36656b6.

Running the following Python script:

import torch
from torch2trt import torch2trt
import tensorrt
print('torch version: ', torch.__version__)
print('tensorrt version: ', tensorrt.__version__)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Load MiDaS from hub
midas = torch.hub.load("intel-isl/MiDaS", "DPT_Large")
midas.to(device)
midas.half().eval()

data = torch.randn((1, 3, 384, 672)).cuda().half()

model_trt = torch2trt(midas, [data], fp16_mode=True)

Gives the following error:

torch version: 2.0.0+nv23.05
tensorrt version: 8.5.2.2
Using cache found in /home/shane/.cache/torch/hub/intel-isl_MiDaS_master
Warning: Encountered known unsupported method torch.Tensor.__len__
Traceback (most recent call last):
  File "midas2trt.py", line 15, in <module>
    model_trt = torch2trt(midas, [data], fp16_mode=True, strict_type_constraints=True)
  File "/home/shane/Project/torch2trt/torch2trt/torch2trt.py", line 779, in torch2trt
    outputs = module(*inputs)
  File "/home/shane/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/shane/.cache/torch/hub/intel-isl_MiDaS_master/midas/dpt_depth.py", line 166, in forward
    return super().forward(x).squeeze(dim=1)
  File "/home/shane/.cache/torch/hub/intel-isl_MiDaS_master/midas/dpt_depth.py", line 114, in forward
    layers = self.forward_transformer(self.pretrained, x)
  File "/home/shane/.cache/torch/hub/intel-isl_MiDaS_master/midas/backbones/vit.py", line 13, in forward_vit
    return forward_adapted_unflatten(pretrained, x, "forward_flex")
  File "/home/shane/.cache/torch/hub/intel-isl_MiDaS_master/midas/backbones/utils.py", line 86, in forward_adapted_unflatten
    exec(f"glob = pretrained.model.{function_name}(x)")
  File "<string>", line 1, in <module>
  File "/home/shane/.cache/torch/hub/intel-isl_MiDaS_master/midas/backbones/vit.py", line 41, in forward_flex
    pos_embed = self._resize_pos_embed(
  File "/home/shane/.cache/torch/hub/intel-isl_MiDaS_master/midas/backbones/vit.py", line 30, in _resize_pos_embed
    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
  File "/home/shane/Project/torch2trt/torch2trt/torch2trt.py", line 310, in wrapper
    converter["converter"](ctx)
  File "/home/shane/Project/torch2trt/torch2trt/converters/interpolate.py", line 63, in convert_interpolate_trt7
    input_trt = add_missing_trt_tensors(ctx.network, [input])[0]
  File "/home/shane/Project/torch2trt/torch2trt/torch2trt.py", line 176, in add_missing_trt_tensors
    t._trt = network.add_constant(shape, weight).get_output(0)
  File "/home/shane/Project/torch2trt/torch2trt/torch2trt.py", line 400, in wrapper
    ret = attr(*args, **kwargs)
TypeError: add_constant(): incompatible function arguments. The following argument types are supported:

  1. (self: tensorrt.tensorrt.INetworkDefinition, shape: tensorrt.tensorrt.Dims, weights: tensorrt.tensorrt.Weights) -> tensorrt.tensorrt.IConstantLayer

Invoked with: <tensorrt.tensorrt.INetworkDefinition object at 0xffff1028a7f0>, (1024, 24, 24), array([[[[ 3.0579e-02,  1.6296e-02,  5.4321e-03, ...,  8.8692e-04,
       4.4289e-03,  4.6730e-03],
     [ 2.7710e-02,  1.8799e-02,  2.9659e-03, ..., -5.9929e-03,
      -2.0809e-03,  6.8550e-03],
     [ 1.7151e-02,  1.4435e-02,  1.0551e-02, ..., -3.3550e-03,
      -2.8324e-03,  4.1161e-03],
     ...,
     [ 2.6443e-02,  1.2093e-02,  1.1108e-02, ...,  7.6828e-03,
       3.6526e-03,  1.2993e-02],
     [ 1.8646e-02,  1.3718e-02,  1.6861e-02, ...,  1.1475e-02,
       7.1182e-03,  8.3008e-03],
     [ 1.2863e-02,  1.1520e-02,  1.5106e-02, ...,  1.4595e-02,
       8.3303e-04,  8.6117e-04]],

    [[-2.6550e-03,  9.3842e-03,  1.0117e-02, ..., -1.1024e-02,
      -4.7874e-03,  1.2131e-02],
     [-4.5052e-03,  5.0163e-03,  3.4180e-03, ..., -4.2648e-03,
       6.3972e-03,  1.0445e-02],
     [-1.5244e-02, -5.4169e-03,  1.5535e-03, ...,  8.1024e-03,
       2.5749e-03,  4.5433e-03],
     ...,
     [ 3.4657e-03,  5.9128e-03,  6.0425e-03, ...,  4.9820e-03,
       2.5439e-04,  7.9422e-03],
     [-1.6890e-03,  7.6981e-03,  1.0254e-02, ..., -1.8415e-03,
       1.3008e-03,  5.3673e-03],
     [-1.3306e-02,  8.9884e-04,  4.3678e-03, ..., -1.4858e-03,
      -1.6670e-03, -8.3971e-04]],

    [[ 1.1230e-02, -1.3218e-03, -1.6174e-02, ..., -1.1299e-02,
      -2.2316e-03,  2.6264e-03],
     [ 2.8625e-02,  1.0063e-02, -8.9874e-03, ..., -8.9722e-03,
      -5.3406e-03,  7.0953e-03],
     [ 3.5645e-02,  1.1154e-02, -1.1566e-02, ..., -1.0147e-02,
       5.9271e-04,  1.6708e-02],
     ...,
     [ 5.3644e-04,  7.3433e-04, -8.1940e-03, ...,  5.7640e-03,
       1.1017e-02,  1.3222e-02],
     [ 1.0284e-02, -6.8893e-03, -1.9470e-02, ...,  1.1311e-03,
       1.2550e-02,  1.0445e-02],
     [ 2.0020e-02, -6.2637e-03, -2.8351e-02, ..., -8.7433e-03,
      -3.0403e-03, -4.7913e-03]],

    ...,

    [[ 4.4769e-02,  5.8411e-02,  5.4565e-02, ...,  3.5156e-02,
       3.0716e-02,  2.7786e-02],
     [ 4.7211e-02,  5.5176e-02,  4.8798e-02, ...,  1.9547e-02,
       1.2978e-02,  2.2446e-02],
     [ 3.0609e-02,  4.0070e-02,  3.1036e-02, ..., -2.0676e-03,
       7.0810e-04,  1.0612e-02],
     ...,
     [ 2.4292e-02,  4.3427e-02,  4.2664e-02, ...,  4.5657e-05,
      -6.5765e-03,  8.6164e-04],
     [ 4.4632e-04,  2.1652e-02,  1.7090e-02, ..., -2.3163e-02,
      -2.4368e-02, -1.6266e-02],
     [-2.3773e-02, -1.1459e-02, -1.8829e-02, ..., -3.8055e-02,
      -4.2053e-02, -3.2867e-02]],

    [[-4.9286e-03, -1.0094e-02, -6.1684e-03, ..., -1.3992e-02,
      -7.7629e-03, -3.9597e-03],
     [-4.9934e-03, -3.4142e-03,  5.1308e-03, ..., -1.0300e-03,
      -2.7370e-03, -6.5384e-03],
     [-5.4092e-03,  8.8024e-04, -5.7936e-04, ...,  3.0479e-03,
       4.8294e-03,  2.6417e-03],
     ...,
     [-1.5915e-02, -2.5406e-03,  5.1765e-03, ..., -1.2680e-02,
      -6.5956e-03,  2.1305e-03],
     [-4.7302e-04,  1.0025e-02,  9.2239e-03, ..., -4.4136e-03,
       3.4981e-03,  1.5167e-02],
     [ 3.3661e-02,  1.9211e-02,  1.3771e-02, ...,  8.0414e-03,
       1.3657e-02,  2.7832e-02]],

    [[-8.9493e-03,  4.8676e-03,  1.5511e-02, ...,  2.1572e-03,
      -8.4763e-03, -3.3478e-02],
     [-1.2062e-02, -6.0425e-03,  4.1962e-03, ...,  7.0229e-03,
      -2.0386e-02, -4.2450e-02],
     [-1.2749e-02, -4.9257e-04,  7.2594e-03, ..., -2.9335e-03,
      -1.8845e-02, -4.4006e-02],
     ...,
     [ 2.1301e-02,  3.2349e-02,  3.3569e-02, ...,  1.2421e-02,
       8.3084e-03, -6.3438e-03],
     [ 9.8648e-03,  2.1255e-02,  2.6810e-02, ...,  1.6281e-02,
       3.7117e-03, -5.2223e-03],
     [-1.0017e-02,  1.8311e-03,  1.4275e-02, ...,  1.5045e-02,
       4.3869e-03, -2.0721e-02]]]], dtype=float16)
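In this log, the shape argument (1024, 24, 24) has one fewer dimension than the 4-D array passed with it. That pattern is consistent with torch2trt dropping leading size-1 dimensions from the shape before calling add_constant while passing the array itself unchanged. The element counts still agree. The sketch below mimics that behaviour under that assumption; the helper names are hypothetical and not torch2trt's API. The value 24 is also the per-side grid that a 384-pixel ViT with 16-pixel patches produces (384 / 16).

```python
def strip_leading_ones(shape):
    """Drop size-1 dimensions from the front of `shape`, stopping at the first non-1 dim.

    Hypothetical helper mirroring what the logged shape suggests torch2trt does.
    """
    shape = tuple(int(d) for d in shape)
    i = 0
    while i < len(shape) and shape[i] == 1:
        i += 1
    return shape[i:]


def num_elements(shape):
    """Product of the dimensions (1 for an empty shape)."""
    n = 1
    for d in shape:
        n *= d
    return n


# Assumed shape of the position-embedding grid fed to F.interpolate in the MiDaS log.
array_shape = (1, 1024, 24, 24)
trt_shape = strip_leading_ones(array_shape)
print(trt_shape)                                              # (1024, 24, 24)
print(num_elements(trt_shape) == num_elements(array_shape))  # True
```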

I tried these solutions from other threads:

  1. strict_type_constraints=True

What could be the cause of this error? Can anyone help?

Since I could not find a solution in torch2trt, I am moving on to try torch -> ONNX -> engine for now, as @fused-byte suggested here.

usama-baloch commented 9 months ago

@jakubhejhal @shanek16 The add_constant layer expects its weights argument to be of type trt.Weights, but I suspect your weight is a NumPy array. A NumPy array can't be fed directly to add_constant; you need to wrap it in trt.Weights to make it compatible. Try updating the failing line (in add_missing_trt_tensors in torch2trt/torch2trt.py):

t._trt = network.add_constant(shape, trt.Weights(weight)).get_output(0)
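Below is a minimal sketch, assuming NumPy, of how the weight could be prepared before that call. It goes beyond the trt.Weights wrapping suggested above in two ways. It reshapes the array to the (possibly reduced) shape, and it makes the array C-contiguous with np.ascontiguousarray. Both steps are defensive assumptions and are not confirmed in this thread. The helper name prepare_constant is hypothetical. The TensorRT call is left in a comment so the sketch runs without TensorRT installed.

```python
import numpy as np


def prepare_constant(weight, shape):
    """Return a C-contiguous version of `weight` reshaped to `shape`.

    Hypothetical helper: the element count must match, otherwise reshape raises.
    """
    weight = np.ascontiguousarray(weight)  # copies only if not already C-contiguous
    return weight.reshape(shape)


# Transposed/permuted views are generally not C-contiguous.
k = np.random.rand(4, 12, 64, 197).astype(np.float32).transpose(0, 1, 3, 2)
print(k.flags["C_CONTIGUOUS"])  # False

w = prepare_constant(k, (4, 12, 197, 64))
print(w.flags["C_CONTIGUOUS"], w.shape)  # True (4, 12, 197, 64)

# Inside torch2trt's add_missing_trt_tensors this would then be used as
# (assuming `import tensorrt as trt` and the `network`/`t`/`shape` names from that function):
#   t._trt = network.add_constant(shape, trt.Weights(prepare_constant(weight, shape))).get_output(0)
```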

Hukongtao commented 2 weeks ago


I solved the above error following your suggestion, but ran into a new problem later in the conversion process (posted as a screenshot).

usama-baloch commented 2 weeks ago


The last error occurs in a slice layer of the network. Debug your code and check the inputs and outputs of the slice layers used during conversion.

For more information on how slice layers work, see the TensorRT documentation: Slice Layer
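When checking slice layers, it can help to work out by hand the start, size and stride that a given Python slice should map to. You can then compare these with what the converter passes to TensorRT's slice layer, which takes a start, shape and stride per dimension. Below is a small hypothetical helper for that check, using only the standard library's slice.indices. It is a debugging aid, not torch2trt's own slicing code.

```python
def slice_params(s, dim_size):
    """Map a Python slice on a dimension of length `dim_size` to (start, size, stride).

    Hypothetical debugging helper; `size` is the number of elements the slice selects.
    """
    start, stop, step = s.indices(dim_size)  # normalises None and negative indices
    size = len(range(start, stop, step))     # count of selected elements (may be 0)
    return start, size, step


# x[:, 0:1] on a 197-token dimension (e.g. keeping only the class token):
print(slice_params(slice(0, 1), 197))           # (0, 1, 1)
# x[:, 1:] keeps the remaining 196 patch tokens:
print(slice_params(slice(1, None), 197))        # (1, 196, 1)
# Strided and negative-index slices:
print(slice_params(slice(None, None, 2), 10))   # (0, 5, 2)
print(slice_params(slice(-3, None), 10))        # (7, 3, 1)
```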