pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

Unable to generate a TensorRT GeneralizedRCNNTransform model with a dynamic batch #5858

Open montmejat opened 2 years ago

montmejat commented 2 years ago

🐛 Describe the bug

I'm trying to generate a TensorRT engine for RetinaNet, which uses GeneralizedRCNNTransform. By bypassing a couple of layers, the conversion works fine with a static batch size. However, with a dynamic batch size, I run into an error that I don't know how to fix.

In class GeneralizedRCNNTransform(nn.Module), the forward method contains this loop:

for i in range(len(images)):
    image = images[i]
    target_index = targets[i] if targets is not None else None

    if image.dim() != 3:
        raise ValueError("images is expected to be a list of 3d tensors "
                            "of shape [C, H, W], got {}".format(image.shape))
    image = self.normalize(image) # this is where the Sub and Div nodes come from
    image, target_index = self.resize(image, target_index) # I'm bypassing this
    images[i] = image
    if targets is not None and target_index is not None:
        targets[i] = target_index

When converting my model with:

torch.onnx.export(
    model,
    example,
    onnx_model_path.split('.')[0] + '_dynamic.onnx',
    verbose=False,
    opset_version=11,
    input_names=['input'],
    dynamic_axes={
        'input': {0: 'batch_size'}
    }
)
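
A quick way to sanity-check the exported graph outside of TensorRT would be to run it through onnxruntime with a batch size different from the example used during export. For instance (a minimal sketch, assuming onnxruntime is installed; the file name and the 512×512 input shape are just placeholders):

import numpy as np
import onnxruntime as ort

# Load the exported model and feed a batch size different from the
# example batch of 3 used during export (file name is hypothetical)
session = ort.InferenceSession("retinanet_dynamic.onnx")
batch = np.random.rand(5, 3, 512, 512).astype(np.float32)  # batch of 5
outputs = session.run(None, {"input": batch})
print([o.shape for o in outputs])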

I get the following tree:

[screenshot of the exported ONNX graph]

This is perfectly fine for the example batch size (3 in this case), but the graph will not work for any other batch size once it is converted to a TensorRT engine. I get this kind of error:

[12/10/2021-13:29:40] [TRT] [E] 7: [shapeMachine.cpp::execute::565] Error Code 7: Internal Error (Split_0_0: ISliceLayer has out of bounds access on axis 0
condition '<' violated
Instruction: CHECK_LESS 1 1
)
[12/10/2021-13:29:40] [TRT] [E] 2: [executionContext.cpp::enqueueInternal::366] Error Code 2: Internal Error (Could not resolve slots: )

Is there any way to bypass this for loop, which is what creates the Split and Concat nodes? Do I need to redefine the forward method? I can't really skip the normalization, or I will lose accuracy. Can I maybe normalize everything at once, without the for loop?
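
For example, something like this broadcasted version is what I have in mind (a minimal sketch, not the torchvision implementation; it assumes the input is already a padded [N, C, H, W] tensor):

import torch
from torch import nn

class BatchedNormalize(nn.Module):
    # Normalizes a whole padded [N, C, H, W] batch with broadcasting, so the
    # exported graph gets a single Sub and Div instead of one pair per image.
    def __init__(self, image_mean, image_std):
        super().__init__()
        self.register_buffer("mean", torch.tensor(image_mean).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor(image_std).view(1, -1, 1, 1))

    def forward(self, images):
        # images: [N, C, H, W]; mean/std broadcast over batch and spatial dims
        return (images - self.mean) / self.std

I imagine something like this could be swapped in for the per-image loop (e.g. in an export-only subclass of GeneralizedRCNNTransform), keeping the same image_mean/image_std so accuracy isn't affected, but I'm not sure that's the intended way.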

Thanks 😄

Versions

Collecting environment information...
PyTorch version: 1.10.0
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.6 LTS (aarch64)
GCC version: (Ubuntu/Linaro 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.10.2
Libc version: glibc-2.25

Python version: 3.6.9 (default, Mar 15 2022, 13:55:28) [GCC 8.4.0] (64-bit runtime)
Python platform: Linux-4.9.253-tegra-aarch64-with-Ubuntu-18.04-bionic
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/aarch64-linux-gnu/libcudnn.so.8.2.1
/usr/lib/aarch64-linux-gnu/libcudnn_adv_infer.so.8.2.1
/usr/lib/aarch64-linux-gnu/libcudnn_adv_train.so.8.2.1
/usr/lib/aarch64-linux-gnu/libcudnn_cnn_infer.so.8.2.1
/usr/lib/aarch64-linux-gnu/libcudnn_cnn_train.so.8.2.1
/usr/lib/aarch64-linux-gnu/libcudnn_ops_infer.so.8.2.1
/usr/lib/aarch64-linux-gnu/libcudnn_ops_train.so.8.2.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.19.5
[pip3] torch==1.10.0
[pip3] torch2trt==0.3.0
[pip3] torchvision==0.11.1
[conda] Could not collect

cc @datumbox @YosuaMichael

datumbox commented 2 years ago

@aurelien-m Thanks for reporting.

This is not currently a use case we support. It's hard to guide you to the best option because I don't have deep expertise in TensorRT and ONNX. We could potentially look into this in the future, but right now it's difficult due to our limited resources. Apologies that I can't provide better assistance at this point.