NVIDIA-AI-IOT / torch2trt

An easy to use PyTorch to TensorRT converter
MIT License

Interpolate Plugin [nn.Upsample vs F.interpolate] #158

Closed: SrivastavaKshitij closed this issue 4 years ago

SrivastavaKshitij commented 4 years ago

Hi John

While working on the pull request, I found that the interpolate plugin works fine when the model calls F.interpolate, but the converted output is wrong when nn.Upsample is used instead. This happens in bilinear mode. I was able to reproduce the error as follows:

  1. Change the Interpolate class definition to:
import torch
import torch.nn.functional as F


class Interpolate(torch.nn.Module):
    def __init__(self, size, mode, align_corners, scale_factor=None):
        super(Interpolate, self).__init__()
        self.size = size
        self.mode = mode
        self.align_corners = align_corners
        self.scale_factor = scale_factor

    def forward(self, x):
        if self.scale_factor is None:
            # Resize to an explicit output size
            return F.interpolate(x, self.size, mode=self.mode, align_corners=self.align_corners)
        else:
            # Resize by a scale factor; every argument after x is passed by keyword
            return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)

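  2. Add two unit tests at the end of the same file (torch2trt/converters/interpolate/interpolate.py):

from torch2trt.module_test import add_module_test


# Upsample by 2 through F.interpolate (via the Interpolate module above)
@add_module_test(torch.float32, torch.device('cuda'), [(1, 256, 192, 512)])
def test_scale():
    return Interpolate(None, 'bilinear', False, scale_factor=2)


# The same upsampling through nn.Upsample
@add_module_test(torch.float32, torch.device('cuda'), [(1, 256, 192, 512)])
def test_nn_scale():
    return torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)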

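I get the following results. The F.interpolate version matches PyTorch exactly (max error 0), while the nn.Upsample version has a max error of 2.99:

| Name | Data Type | Input Shapes | torch2trt kwargs | Max Error | Throughput, PyTorch (FPS) | Throughput, TensorRT (FPS) | Latency, PyTorch (ms) | Latency, TensorRT (ms) |
|---|---|---|---|---|---|---|---|---|
| torch2trt.converters.interpolate.interpolate.test_scale | float32 | [(1, 256, 192, 512)] | {} | 0.00E+00 | 613 | 785 | 1.66 | 1.33 |
| torch2trt.converters.interpolate.interpolate.test_nn_scale | float32 | [(1, 256, 192, 512)] | {} | 2.99E+00 | 605 | 914 | 1.68 | 1.15 |
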
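
The Max Error column presumably reports the largest element-wise absolute difference between the PyTorch and TensorRT outputs (an assumption about torch2trt's test harness). The pure-Python sketch below illustrates that metric and why the align_corners setting matters in bilinear mode. Bilinear resizing is separable, so a 1D linear upsample is enough to show that the same row upsampled with align_corners=False and with align_corners=True gives different values, and therefore a nonzero max error. The function names are hypothetical; this is not torch2trt's or TensorRT's code, only a sketch that follows the source-coordinate mapping PyTorch uses for linear modes, assuming an integer scale factor.

```python
import math


def upsample_linear_1d(values, scale, align_corners):
    """Hypothetical 1D linear upsampling mirroring PyTorch's coordinate mapping."""
    in_size = len(values)
    out_size = int(math.floor(in_size * scale))
    out = []
    for i in range(out_size):
        if align_corners:
            # Corner samples of the input and output line up exactly.
            x = i * (in_size - 1) / (out_size - 1) if out_size > 1 else 0.0
        else:
            # Half-pixel centres; negative source coordinates are clamped to 0.
            x = max((i + 0.5) / scale - 0.5, 0.0)
        x0 = min(int(math.floor(x)), in_size - 1)
        x1 = min(x0 + 1, in_size - 1)  # clamp the right neighbour at the border
        w = x - x0
        out.append(values[x0] * (1 - w) + values[x1] * w)
    return out


def max_abs_error(a, b):
    """Largest element-wise absolute difference between two outputs."""
    return max(abs(p - q) for p, q in zip(a, b))


row = [0.0, 10.0]
half_pixel = upsample_linear_1d(row, 2, align_corners=False)  # [0.0, 2.5, 7.5, 10.0]
corners = upsample_linear_1d(row, 2, align_corners=True)      # roughly [0.0, 3.33, 6.67, 10.0]
print(max_abs_error(half_pixel, half_pixel))  # 0.0
print(max_abs_error(half_pixel, corners))     # roughly 0.833
```

If a converter applied a different align_corners convention (or a different mode) than the PyTorch module, a mismatch of this kind is what the Max Error column would show; whether that is what happened here is not established by the thread.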
jasonliu19 commented 4 years ago

I have noticed the same error, which is strange, since I thought nn.Upsample used F.interpolate under the hood.
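nn.Upsample does call F.interpolate, but its forward method passes size, scale_factor, mode and align_corners positionally, whereas the Interpolate test module passes everything after the input tensor by keyword. One plausible explanation for the mismatch, which this thread does not confirm, is a converter that reads those settings only from keyword arguments and silently falls back to defaults otherwise. The sketch below uses hypothetical helpers (get_arg_keyword_only, get_arg) over plain (args, kwargs) tuples rather than torch2trt's converter context, and assumes the parameter order of F.interpolate(input, size, scale_factor, mode, align_corners, ...).

```python
# Parameter positions and defaults, following F.interpolate's signature order.
POSITIONS = {"input": 0, "size": 1, "scale_factor": 2, "mode": 3, "align_corners": 4}
DEFAULTS = {"size": None, "scale_factor": None, "mode": "nearest", "align_corners": None}


def get_arg_keyword_only(args, kwargs, name):
    """Hypothetical naive lookup: ignores positional arguments entirely."""
    return kwargs.get(name, DEFAULTS[name])


def get_arg(args, kwargs, name):
    """Hypothetical robust lookup: keyword first, then positional slot, then default."""
    if name in kwargs:
        return kwargs[name]
    pos = POSITIONS[name]
    if pos < len(args):
        return args[pos]
    return DEFAULTS[name]


x = "input tensor"  # placeholder; the lookup never touches the tensor itself

# Interpolate.forward (scale_factor branch): everything after x is a keyword.
keyword_call = ((x,), {"scale_factor": 2, "mode": "bilinear", "align_corners": False})
# nn.Upsample.forward: size, scale_factor, mode, align_corners are positional.
positional_call = ((x, None, 2, "bilinear", False), {})

for label, (args, kwargs) in [("keyword", keyword_call), ("positional", positional_call)]:
    print(label, get_arg_keyword_only(args, kwargs, "mode"), get_arg(args, kwargs, "mode"))
# prints:
# keyword bilinear bilinear
# positional nearest bilinear
```

With the keyword-only lookup, the nn.Upsample call would be treated as nearest-neighbour resizing, which would produce exactly the kind of large max error seen for test_nn_scale while leaving test_scale unaffected.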

SrivastavaKshitij commented 4 years ago

OK, I have a fix for the problem. I will create a pull request tomorrow.