traveller59 / torch2trt

Convert a PyTorch module to a TensorRT network or a TVM function.
MIT License
89 stars 19 forks source link

Fix multi-input bug (#4)

Closed: GeoffreyChen777 closed this issue 5 years ago.

GeoffreyChen777 commented 5 years ago

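Here is a minimal script that reproduces the bug. It prints the summed absolute difference between the plain PyTorch output and the torch2trt-wrapped output: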

import torch
import torch.nn as nn
import torchvision
import torch2trt

# Build settings handed positionally to torch2trt.TensorRTModuleWrapper.
# Presumably the maximum batch size the TensorRT engine is built for — confirm
# against the torch2trt API.
max_batchsize = 1
# Workspace size in bytes: 2 ** 30 == 1 GiB.
max_trt_workspace = 2 ** 30

class Net(nn.Module):
    """Toy model whose ``forward`` takes two tensors, used to reproduce the bug.

    A single convolution (3 -> 64 channels, 3x3 kernel, stride 1, padding 1)
    is applied to each input with the same weights. The two resulting feature
    maps are summed.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)

    def forward(self, x, y):
        # Both branches share one set of conv weights; the results are added elementwise.
        feat_x, feat_y = self.conv(x), self.conv(y)
        return feat_x + feat_y

# Reference model: plain PyTorch, on the GPU, in eval mode.
net = Net().cuda().eval()

# Two independent random inputs, drawn in order (inputs1 first, then inputs2)
# so the RNG sequence matches the original report.
inputs1, inputs2 = (torch.rand(1, 3, 4, 4).float().cuda() for _ in range(2))

# Wrap the same module with torch2trt, passing the build settings positionally
# exactly as in the report.
net_trt = (
    torch2trt.TensorRTModuleWrapper(
        net,
        max_batchsize,
        max_trt_workspace,
        param_exclude=".*AuxLogits.*",
    )
    .cuda()
    .eval()
)

# Run both models on the same pair of inputs and print the summed absolute
# difference between their outputs.
out_ref = net(inputs1, inputs2)
out = net_trt(inputs1, inputs2)
print("Diff: {}".format((out_ref - out).abs().sum()))
