pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT
BSD 3-Clause "New" or "Revised" License

🐛 [Bug] `torchvision.ops.roi_align` is throwing an internal assert failed #1157

Closed. BrettRyland closed this issue 1 year ago.

BrettRyland commented 2 years ago

Bug Description

torchvision.ops.roi_align triggers an internal assert failure when a model that uses it is compiled with torch_tensorrt.compile.

To Reproduce

Repro script: trt_bug3.py

brett@br-workhorse:~/repos/Autosensor/NN/tmp$ python3 trt_bug3.py 
torch.Size([20, 3, 6, 9])
torch.Size([20, 3, 6, 9])
WARNING: [Torch-TensorRT] - Cannot infer input type from calcuations in graph for input x.1. Assuming it is Float32. If not, specify input type explicity
WARNING: [Torch-TensorRT] - Input type for doing shape analysis could not be determined, defaulting to F32
WARNING: [Torch-TensorRT] - Truncating graph input type from at::kLong to at::kInt
WARNING: [Torch-TensorRT] - Truncating graph input type from at::kLong to at::kInt
WARNING: [Torch-TensorRT] - Truncating graph input type from at::kLong to at::kInt
WARNING: [Torch-TensorRT] - Truncating graph input type from at::kLong to at::kInt
WARNING: [Torch-TensorRT] - Truncating graph input type from at::kLong to at::kInt
WARNING: [Torch-TensorRT] - Truncating graph input type from at::kLong to at::kInt
WARNING: [Torch-TensorRT] - Truncating graph input type from at::kLong to at::kInt
WARNING: [Torch-TensorRT] - Truncating graph input type from at::kLong to at::kInt
WARNING: [Torch-TensorRT] - Truncating graph input type from at::kLong to at::kInt
WARNING: [Torch-TensorRT] - Truncating graph input type from at::kLong to at::kInt
WARNING: [Torch-TensorRT] - Truncating graph input type from at::kLong to at::kInt
WARNING: [Torch-TensorRT] - Truncating graph input type from at::kLong to at::kInt
WARNING: [Torch-TensorRT] - Truncating graph input type from at::kLong to at::kInt
WARNING: [Torch-TensorRT] - Truncating graph input type from at::kLong to at::kInt
WARNING: [Torch-TensorRT] - Truncating graph input type from at::kLong to at::kInt
WARNING: [Torch-TensorRT] - Truncating graph input type from at::kLong to at::kInt
WARNING: [Torch-TensorRT TorchScript Conversion Context] - TensorRT was linked against cuBLAS/cuBLAS LT 11.8.0 but loaded cuBLAS/cuBLAS LT 111.0.1
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT TorchScript Conversion Context] - TensorRT was linked against cuBLAS/cuBLAS LT 11.8.0 but loaded cuBLAS/cuBLAS LT 111.0.1
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT TorchScript Conversion Context] - TensorRT was linked against cuBLAS/cuBLAS LT 11.8.0 but loaded cuBLAS/cuBLAS LT 111.0.1
WARNING: [Torch-TensorRT TorchScript Conversion Context] - Unused Input: input_1
WARNING: [Torch-TensorRT TorchScript Conversion Context] - [RemoveDeadLayers] Input Tensor input_1 is unused or used only at compile-time, but is not being removed.
WARNING: [Torch-TensorRT TorchScript Conversion Context] - TensorRT was linked against cuBLAS/cuBLAS LT 11.8.0 but loaded cuBLAS/cuBLAS LT 111.0.1
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT TorchScript Conversion Context] - Unused Input: input_3
WARNING: [Torch-TensorRT TorchScript Conversion Context] - [RemoveDeadLayers] Input Tensor input_3 is unused or used only at compile-time, but is not being removed.
WARNING: [Torch-TensorRT TorchScript Conversion Context] - TensorRT was linked against cuBLAS/cuBLAS LT 11.8.0 but loaded cuBLAS/cuBLAS LT 111.0.1
WARNING: [Torch-TensorRT TorchScript Conversion Context] - TensorRT was linked against cuBLAS/cuBLAS LT 11.8.0 but loaded cuBLAS/cuBLAS LT 111.0.1
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT TorchScript Conversion Context] - TensorRT was linked against cuBLAS/cuBLAS LT 11.8.0 but loaded cuBLAS/cuBLAS LT 111.0.1
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT TorchScript Conversion Context] - TensorRT was linked against cuBLAS/cuBLAS LT 11.8.0 but loaded cuBLAS/cuBLAS LT 111.0.1
WARNING: [Torch-TensorRT TorchScript Conversion Context] - Unused Input: input_1
WARNING: [Torch-TensorRT TorchScript Conversion Context] - [RemoveDeadLayers] Input Tensor input_1 is unused or used only at compile-time, but is not being removed.
WARNING: [Torch-TensorRT TorchScript Conversion Context] - TensorRT was linked against cuBLAS/cuBLAS LT 11.8.0 but loaded cuBLAS/cuBLAS LT 111.0.1
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT TorchScript Conversion Context] - Unused Input: input_3
WARNING: [Torch-TensorRT TorchScript Conversion Context] - [RemoveDeadLayers] Input Tensor input_3 is unused or used only at compile-time, but is not being removed.
WARNING: [Torch-TensorRT TorchScript Conversion Context] - TensorRT was linked against cuBLAS/cuBLAS LT 11.8.0 but loaded cuBLAS/cuBLAS LT 111.0.1
Traceback (most recent call last):
  File "/home/brett/repos/Autosensor/NN/tmp/trt_bug3.py", line 79, in <module>
    trt_m = torch_tensorrt.compile(scr_m, inputs=[torch_tensorrt.Input((2, 3, 10, 10)), ], truncate_long_and_double=True)
  File "/home/brett/.local/lib/python3.10/site-packages/torch_tensorrt/_compile.py", line 115, in compile
    return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
  File "/home/brett/.local/lib/python3.10/site-packages/torch_tensorrt/ts/_compiler.py", line 113, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/home/brett/.local/lib/python3.10/site-packages/torchvision/ops/_utils.py", line 30, in forward
def check_roi_boxes_shape(boxes: Union[Tensor, List[Tensor]]):
    if isinstance(boxes, (list, tuple)):
        for _tensor in boxes:
        ~~~~~~~~~~~~~~~~~~~~~
            assert (
            ~~~~~~~~
                _tensor.size(1) == 4
                ~~~~~~~~~~~~~~~~~~~~
            ), "The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]"
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    elif isinstance(boxes, torch.Tensor):
        assert boxes.size(1) == 5, "The boxes tensor shape is not correct as Tensor[K, 5]"
RuntimeError: isList()INTERNAL ASSERT FAILED at "/home/brett/github/pytorch/aten/src/ATen/core/ivalue_inl.h":1884, please report a bug to PyTorch. Expected GenericList but got Tensor

Expected behavior

Expected graph compilation to succeed.

Environment

brett@br-workhorse:~/repos/Autosensor/NN/tmp$ python3 -c "from torch.utils.collect_env import main; main()"
Collecting environment information...
PyTorch version: 1.11.0a0+gitbc2c6ed
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04 LTS (x86_64)
GCC version: (Ubuntu 11.2.0-19ubuntu1) 11.2.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.4 (main, Apr  2 2022, 09:04:19) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-37-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.7.64
GPU models and configuration: GPU 0: NVIDIA GeForce GTX 1080 Ti
Nvidia driver version: 515.48.07
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.3.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.3.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.3.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.3.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.3.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.3.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.3.2
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.21.6
[pip3] numpy-quaternion==2022.4.1
[pip3] pytorch-ranger==0.1.1
[pip3] torch==1.11.0
[pip3] torch-optimizer==0.1.0
[pip3] torch-tensorrt==1.2.0a0+22d91f5e
[pip3] torchmetrics==0.7.3
[pip3] torchsummary==1.5.1
[pip3] torchvision==0.12.0
[conda] Could not collect

Torch-TensorRT was built from 91a92ca4 with PR #1067 merged in:

brett@br-workhorse:~/github/TensorRT$ git log --oneline --graph
* 22d91f5e (HEAD -> master) fix: fix the bug that tag Constant node as fallback node
*   ccb826e7 Merge remote-tracking branch 'origin' into refactor_segmentation
|\  
| * 91a92ca4 (origin/master, origin/HEAD) docs: [Automated] Regenerating documenation for dcf3386
| *   dcf3386e Merge pull request #1057 from pytorch/anuragd/reduce_nox_sessions

BrettRyland commented 2 years ago

Note: I've been able to partially work around this (only partially, because I then hit a different issue) by modifying torchvision/ops as follows:

diff -x __pycache__ /home/brett/github/torchvision/torchvision/ops/ps_roi_align.py /home/brett/.local/lib/python3.10/site-packages/torchvision/ops/ps_roi_align.py
2a3
> from torch.jit.annotations import BroadcastingList2
13c14
<     output_size: int,
---
>     output_size: BroadcastingList2[int],
67c68
<         output_size: int,
---
>         output_size: BroadcastingList2[int],
diff -x __pycache__ /home/brett/github/torchvision/torchvision/ops/roi_align.py /home/brett/.local/lib/python3.10/site-packages/torchvision/ops/roi_align.py
59c59,61
<     if not isinstance(rois, torch.Tensor):
---
>     if isinstance(rois, torch.Tensor):
>         pass
>     else:
diff -x __pycache__ /home/brett/github/torchvision/torchvision/ops/_utils.py /home/brett/.local/lib/python3.10/site-packages/torchvision/ops/_utils.py
29c29,31
<     if isinstance(boxes, (list, tuple)):
---
>     if isinstance(boxes, torch.Tensor):
>         assert boxes.size(1) == 5, "The boxes tensor shape is not correct as Tensor[K, 5]"
>     elif isinstance(boxes, (list, tuple)):
34,35d35
<     elif isinstance(boxes, torch.Tensor):
<         assert boxes.size(1) == 5, "The boxes tensor shape is not correct as Tensor[K, 5]"

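So it looks like something goes wrong with isinstance checks on tensors: the list/tuple branch is taken even though boxes is a Tensor, and testing for torch.Tensor first avoids the assert.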

BrettRyland commented 2 years ago

As a temporary work-around while this and other issues are getting fixed, is it possible to tell torch_tensorrt.compile to ignore certain sections of the model? E.g., to compile the backbone of the network, but leave the rest as scripted or traced?

BrettRyland commented 2 years ago

I see that the torch_executed_modules option is intended for this, but I haven't been able to get it to work. How does one specify these modules? For example, in the simple script trt_bug4.py, everything always gets compiled regardless of which variants I pass to torch_executed_modules.
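A minimal sketch of one way to look up candidate names, assuming torch_executed_modules expects fully qualified Python class paths of submodule types (for example "torchvision.models.resnet.BasicBlock") rather than attribute names such as "backbone". The helpers qualified_type_name and candidate_executed_modules are hypothetical and not part of Torch-TensorRT; they only read __module__ and __qualname__ from each object's class, so they accept any iterable, including model.modules().

```python
from typing import Iterable, List


def qualified_type_name(obj: object) -> str:
    """Return "<defining module>.<class qualname>" for obj's class.

    For a torchvision BasicBlock instance this gives
    "torchvision.models.resnet.BasicBlock".
    """
    cls = type(obj)
    return f"{cls.__module__}.{cls.__qualname__}"


def candidate_executed_modules(modules: Iterable[object]) -> List[str]:
    """Collect the distinct qualified type names, in first-seen order.

    Hypothetical helper: the assumption is that names of this form are what
    torch_executed_modules matches against.
    """
    # dict.fromkeys removes duplicates while preserving insertion order
    return list(dict.fromkeys(qualified_type_name(m) for m in modules))
```

With a torchvision ResNet, for instance, candidate_executed_modules(model.modules()) would include "torchvision.models.resnet.BasicBlock", which could then be passed as torch_executed_modules=[...] to torch_tensorrt.compile with require_full_compilation=False (its default). Classes defined directly in a script such as trt_bug4.py report __main__ as their module, while TorchScript names them with a __torch__ prefix, so it may be worth checking which form the module-fallback pass actually matches.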

github-actions[bot] commented 1 year ago

This issue has not seen activity for 90 days, Remove stale label or comment or this will be closed in 10 days
