grimoire / torch2trt_dynamic

A PyTorch to TensorRT converter with dynamic shape support
MIT License

Any idea on how to support ConvTranspose2d? #2

Closed huliang2016 closed 3 years ago

huliang2016 commented 3 years ago

Besides, I found it's not working with interpolate. Test code as follows:

import torch
import torch.nn.functional as F
from torch import nn
import tensorrt as trt
from torch2trt import torch2trt

class TestModel(torch.nn.Module):
    def forward(self, x):
        xh, xw = int(x.shape[-2] * 2), int(x.shape[-1] * 2)
        x = F.interpolate(x, size=(xh, xw), mode="nearest")
        return x

test_model = TestModel().cuda()
input_shape = (1, 3, 300, 400)
dummy_tensor = torch.randn(input_shape, dtype=torch.float32).cuda()
# output is (1, 3, 600, 800)
print(test_model(dummy_tensor).shape)

# convert test model to trt
import tensorrt as trt
opt_shape_param = [
    [
        [1, 3, 160, 240],   # min
        [1, 3, 800, 1200],   # opt
        [1, 3, 1600, 2400]    # max
    ]
]

with torch.no_grad():
    trt_model = torch2trt(
        test_model,
        [dummy_tensor],
        fp16_mode=False,
        opt_shape_param=opt_shape_param,
    )

# test trt model
dummy_tensor = torch.randn((1, 3, 400, 400), dtype=torch.float32).cuda()
# expected output is (1, 3, 800, 800), but actually the output shape is still (1, 3, 600, 800)
print(trt_model(dummy_tensor).shape)
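
A possible explanation, not stated in the thread: size=(xh, xw) is computed from the example tensor's shape during tracing, so the resize layer is built with the constant size (600, 800) instead of something derived from the runtime input. A minimal workaround sketch, assuming the fork's interpolate converter handles scale_factor dynamically:

class TestModelScale(torch.nn.Module):
    def forward(self, x):
        # scale_factor avoids reading x.shape in Python, so no fixed size gets baked into the engine
        return F.interpolate(x, scale_factor=2, mode="nearest")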
grimoire commented 3 years ago

Hi
The ConvTranspose2d converter does exist: ConvTranspose2d.py. Try whether it works on your model.
About interpolate, let's discuss it in the PR.

huliang2016 commented 3 years ago

It doesn't seem to work well... test code as follows:

import torch.nn.functional as F
from torch import nn
import tensorrt as trt
import torch
from torch2trt import torch2trt

class TestModel(torch.nn.Module):
    def __init__(self, out_dims):
        super(TestModel, self).__init__()
        self.layer = nn.ConvTranspose2d(3, out_dims, 2, 2)

    def forward(self, x):
        return self.layer(x)

test_model = TestModel(256).cuda()
input_shape = (1, 3, 300, 400)
dummy_tensor = torch.randn(input_shape, dtype=torch.float32).cuda()
# output is (1, 256, 600, 800)
print(test_model(dummy_tensor).shape)

# convert test model to trt
import tensorrt as trt
opt_shape_param = [
    [
        [1, 3, 160, 240],   # min
        [1, 3, 800, 1200],   # opt
        [1, 3, 1600, 2400]    # max
    ]
]

with torch.no_grad():
    trt_model = torch2trt(
        test_model,
        [dummy_tensor],
        fp16_mode=False,
        opt_shape_param=opt_shape_param,
    )

# test trt model
dummy_tensor = torch.randn((1, 3, 300, 400), dtype=torch.float32).cuda()
# expected output is (1, 256, 600, 800)
print(trt_model(dummy_tensor).shape)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/opt/venv/ocr-detection-detectron2/lib/python3.6/site-packages/torch2trt/torch2trt.py in forward(self, *inputs)
    421
    422         for i, input_name in enumerate(self.input_names):
--> 423             idx = self.engine.get_binding_index(input_name)
    424
    425             self.context.set_binding_shape(idx, tuple(inputs[i].shape))

AttributeError: 'NoneType' object has no attribute 'get_binding_index'


* it seems `engine` is None
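
A quick sanity check before running inference, a sketch using only the engine attribute visible in the traceback above; the log_level argument is an assumption carried over from upstream torch2trt and may not exist in this fork:

with torch.no_grad():
    trt_model = torch2trt(
        test_model,
        [dummy_tensor],
        fp16_mode=False,
        opt_shape_param=opt_shape_param,
        log_level=trt.Logger.VERBOSE,  # assumed kwarg; surfaces builder errors if supported
    )

# fail fast if the engine could not be built, instead of hitting AttributeError later
assert trt_model.engine is not None, "TensorRT engine build failed; check the builder log"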
grimoire commented 3 years ago

Ok. I see. I will check it later.

grimoire commented 3 years ago

Hi
I guess it is caused by OOM.

I have reduced the opt_shape_param size and increased max_workspace_size as follows:

opt_shape_param = [
    [
        [1, 3, 160, 240],   # min
        [1, 3, 300, 400],   # opt
        [1, 3, 600, 800]    # max
    ]
]

max_workspace_size = 1<<30
with torch.no_grad():
    trt_model = torch2trt(
        test_model,
        [dummy_tensor],
        fp16_mode=False,
        opt_shape_param=opt_shape_param,
        max_workspace_size=max_workspace_size
    )

It works. cuDNN needs a large workspace to do the deconvolution. I guess (1600, 2400) reaches the memory limit.
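
A short verification sketch (not in the original reply), using input sizes inside the new [min, max] profile to confirm the rebuilt engine follows the runtime shape:

for h, w in [(160, 240), (300, 400), (600, 800)]:
    x = torch.randn((1, 3, h, w), dtype=torch.float32).cuda()
    print(trt_model(x).shape)  # expected (1, 256, 2*h, 2*w)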

huliang2016 commented 3 years ago

WTF, thanks!