NVIDIA-AI-IOT / torch2trt

An easy to use PyTorch to TensorRT converter

Inconsistent inference results between PyTorch and converted TensorRT model with BatchNorm or InstanceNorm operator #910

Open hongliyu0716 opened 9 months ago

hongliyu0716 commented 9 months ago

Description:

I'm experiencing a discrepancy between the inference results of a PyTorch model and those of the TensorRT model obtained by converting it with torch2trt.

Reproduce

1. BatchNorm

```python
import torch
from torch.nn import Module
from torch2trt import torch2trt

para_0 = torch.randn([2, 3, 4, 4, 4], dtype=torch.float32).cuda()
para_1 = torch.randn([3], dtype=torch.float32).cuda()  # running_mean
para_2 = torch.randn([3], dtype=torch.float32).cuda()  # running_var
para_3 = torch.randn([3], dtype=torch.float32).cuda()  # weight
para_4 = torch.randn([3], dtype=torch.float32).cuda()  # bias
para_5 = True                    # training
para_6 = 0.00026441036488630354  # momentum
para_7 = 0.001                   # eps

class batch_norm(Module):
    def forward(self, *args):
        return torch.nn.functional.batch_norm(
            args[0], para_1, para_2, para_3, para_4, para_5, para_6, para_7)

model = batch_norm().float().eval().cuda()
model_trt = torch2trt(model, [para_0])

output = model(para_0)
output_trt = model_trt(para_0)
print(torch.max(torch.abs(output - output_trt)))
```
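Note that `para_5` is passed as `batch_norm`'s `training` argument, so PyTorch normalizes with per-batch statistics here, while torch2trt (as far as I can tell) folds the fixed running statistics into a TensorRT scale layer; that alone would explain a large gap. Below is a minimal sketch of that check, rerunning the comparison with `training=False` and a non-negative `running_var` (values drawn from `torch.randn` can be negative, which is not a valid variance):

```python
# Sketch: same repro, but with training=False so both PyTorch and TensorRT
# normalize with the fixed running statistics. Assumption: with matching
# statistics, the max abs difference should drop to fp32 noise if the
# converter itself is correct.
import torch
from torch.nn import Module
from torch2trt import torch2trt

x = torch.randn([2, 3, 4, 4, 4], dtype=torch.float32).cuda()
running_mean = torch.randn([3]).cuda()
running_var = torch.randn([3]).cuda().abs()  # a variance must be non-negative
weight = torch.randn([3]).cuda()
bias = torch.randn([3]).cuda()

class BatchNormEval(Module):
    def forward(self, x):
        # training=False: normalize with the fixed running statistics
        return torch.nn.functional.batch_norm(
            x, running_mean, running_var, weight, bias,
            training=False, momentum=0.1, eps=0.001)

model = BatchNormEval().float().eval().cuda()
model_trt = torch2trt(model, [x])
print(torch.max(torch.abs(model(x) - model_trt(x))))
```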

2. InstanceNorm
```python
import torch
from torch.nn import Module
from torch2trt import torch2trt

para_0 = torch.randn([2, 3, 4, 4, 4], dtype=torch.float32).cuda()
para_1 = torch.randn([3], dtype=torch.float32).cuda()  # running_mean
para_2 = torch.randn([3], dtype=torch.float32).cuda()  # running_var
para_3 = None   # weight
para_4 = None   # bias
para_5 = False  # use_input_stats
para_6 = 0.3    # momentum
para_7 = 0.001  # eps

class instance_norm(Module):
    def forward(self, *args):
        return torch.nn.functional.instance_norm(
            args[0], para_1, para_2, para_3, para_4, para_5, para_6, para_7)
model = instance_norm().float().eval().cuda()
model_trt = torch2trt(model, [para_0])

output = model(para_0)
output_trt = model_trt(para_0)
print(torch.max(torch.abs(output - output_trt)))
```
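As a side note for whoever triages this: a small helper like the one below (the name `report_diff` is mine) makes it easier to separate converter bugs from ordinary floating-point noise, since fp32 TensorRT engines typically agree with PyTorch to within roughly 1e-4. Note also that with `use_input_stats=False`, `instance_norm` normalizes with the supplied running statistics, and `para_2` drawn from `torch.randn` can be negative, which can produce NaNs that will never compare equal.

```python
import torch

def report_diff(output, output_trt, rtol=1e-3, atol=1e-4):
    # Print the worst-case absolute error and flag anything beyond the
    # tolerance that plain fp32 reordering of operations could explain.
    diff = torch.max(torch.abs(output - output_trt)).item()
    print(f"max abs diff: {diff:.3e}")
    if not torch.allclose(output, output_trt, rtol=rtol, atol=atol):
        print("outputs differ beyond fp32 tolerance -> likely a converter bug")

# Usage with the repro above:
report_diff(output, output_trt)
```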

Environment

edwardnguyen1705 commented 2 months ago

Hi @hongliyu0716, have you solved this issue yet?