NVIDIA-AI-IOT / torch2trt

An easy to use PyTorch to TensorRT converter
MIT License

Instance normalization with FP16 has large errors #776

Open ivan94fi opened 2 years ago

ivan94fi commented 2 years ago

I have a problem with instance normalization: the model outputs diverge substantially when the TensorRT model is built with float16 precision.

Versions:

Example code:

from torch import nn
import torch
import random
import numpy as np

from torch2trt import torch2trt

class Net(nn.Module):
    def __init__(self, nc):
        super().__init__()
        self.conv = nn.Conv2d(nc, nc, 3)
        self.in_norm = nn.InstanceNorm2d(nc, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)

    def forward(self, x):
        x = self.conv(x)
        x = self.in_norm(x)

        return x

seed = 42
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

nc = 32
input_data = torch.rand(1, nc, 256, 256)

with torch.no_grad():
    x = input_data.cuda()

    # model = nn.Sequential(*[Net(nc) for _ in range(32)])
    model = Net(nc)
    model = model.cuda()
    model.eval()

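    # Reference output from the PyTorch model (FP32)
    out = model(x)

    # Convert with FP16 enabled and compare against the PyTorch reference.
    # Note: (out - trtout).max() is the largest *signed* difference;
    # (out - trtout).abs().max() would give the largest absolute deviation.
    trt_model = torch2trt(model, [x], fp16_mode=True)
    trt_model.eval()
    trtout = trt_model(x)

    print(f"greatest difference trt (fp16): {(out - trtout).max().item()}")

    # Convert again with the default (FP32) precision for comparison
    trt_model = torch2trt(model, [x])
    trt_model.eval()
    trtout = trt_model(x)

    print(f"greatest difference trt (fp32): {(out - trtout).max().item()}")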

This is the output I get from this script:

greatest difference trt (fp16): 0.020292997360229492
greatest difference trt (fp32): 1.6689300537109375e-06
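The script reports `(out - trtout).max()`, which is the largest *signed* difference, so a large negative deviation would not show up at all. A small comparison helper that also reports the absolute and relative error may give a fuller picture. The sketch below is hypothetical (the name `compare_outputs` and the choice of metrics are assumptions, not part of torch2trt) and works on NumPy arrays. For example, a reference of `[1, -2, 3]` against `[1, -1, 3]` gives `max_signed` of 0.0 but `max_abs` of 1.0.

```python
import numpy as np


def compare_outputs(reference, candidate, eps=1e-6):
    """Hypothetical helper: summarize how far `candidate` deviates from
    `reference`. `max_signed` mirrors the script's (out - trtout).max()."""
    ref = np.asarray(reference, dtype=np.float64)
    cand = np.asarray(candidate, dtype=np.float64)
    diff = ref - cand
    abs_diff = np.abs(diff)
    return {
        "max_signed": float(diff.max()),
        "max_abs": float(abs_diff.max()),
        # mean absolute error relative to the mean magnitude of the reference
        "mean_rel": float(abs_diff.mean() / (np.abs(ref).mean() + eps)),
    }
```

With the tensors from the script, this could be called as `compare_outputs(out.cpu().numpy(), trtout.float().cpu().numpy())`.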

The errors in float16 mode are too large, especially since this is a single instance normalization layer. My model contains many of them, so the errors propagate and grow.

This may be because the instance normalization itself runs in float16, so the per-channel mean and variance are computed at that precision and accumulate numerical error.
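One way to probe the hypothesis that the normalization itself loses precision in float16 is to compute instance normalization entirely in float16 and in float32 and compare each against a float64 reference. The NumPy sketch below is only an illustration, not TensorRT's actual kernel, so the exact error will differ from what the GPU produces. The input is synthetic random data standing in for the conv output (its shape and scale are assumptions), and the affine weight/bias is omitted.

```python
import numpy as np


def instance_norm(x, eps=1e-5, dtype=np.float32):
    """Normalize each (sample, channel) slice of an (N, C, H, W) array over
    its spatial axes, doing every arithmetic step in `dtype`.
    The affine weight/bias of nn.InstanceNorm2d is omitted for simplicity."""
    x = x.astype(dtype)
    mean = x.mean(axis=(2, 3), keepdims=True, dtype=dtype)
    centered = x - mean
    var = (centered * centered).mean(axis=(2, 3), keepdims=True, dtype=dtype)
    return centered / np.sqrt(var + dtype(eps))


rng = np.random.default_rng(42)
# Synthetic stand-in for the conv output (shape and scale are assumptions).
x = rng.standard_normal((1, 4, 64, 64)) * 2.0 + 1.0

reference = instance_norm(x, dtype=np.float64)
err_fp32 = float(np.abs(instance_norm(x, dtype=np.float32) - reference).max())
err_fp16 = float(np.abs(instance_norm(x, dtype=np.float16) - reference).max())
print(f"max abs error  fp32: {err_fp32:.2e}  fp16: {err_fp16:.2e}")
```

If the float16 error here is orders of magnitude above the float32 error, that is consistent with (though not proof of) the precision of the normalization being the main contributor.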

deephog commented 2 years ago

I'm having a similar issue; I don't know whether it is an FP16 overflow problem or an InstanceNorm problem.
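A possible way to separate an FP16 overflow from a precision problem is to record activations from an FP32 run and check their magnitudes against the float16 limit of 65504. The helper below is hypothetical. It also flags values above roughly 256, whose square no longer fits in float16, as a rough warning sign for the variance term in instance normalization; whether TensorRT keeps such intermediates in FP16 is not established in this thread.

```python
import numpy as np

FP16_MAX = float(np.finfo(np.float16).max)  # 65504.0


def fp16_overflow_report(activations):
    """Hypothetical helper: for each named activation captured from an FP32
    run, report its peak magnitude, whether that peak overflows float16, and
    whether its square would (a rough warning sign for variance terms)."""
    report = {}
    for name, act in activations.items():
        peak = float(np.abs(np.asarray(act, dtype=np.float64)).max())
        report[name] = {
            "peak": peak,
            "overflows_fp16": peak > FP16_MAX,
            "square_overflows_fp16": peak * peak > FP16_MAX,
        }
    return report
```

In PyTorch, the activations could be collected by registering a forward hook (`module.register_forward_hook`) on each layer and storing `output.detach().cpu().numpy()` under a name of your choosing.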

GabrieldeBlois commented 2 years ago

Hi,

I'm having a similar issue trying to convert to FP16. The conversion worked fine with earlier versions of PyTorch; the problem started when I upgraded torch from 1.8 to 1.10.

So one quick fix for you would be to downgrade to an earlier version of PyTorch, which might solve your problem. It is not a long-term solution, though, because we will need to upgrade at some point. We will need to find the root cause of this problem anyway.

For now, I'm reading the changelog to see whether anything critical changed.

ivan94fi commented 2 years ago

Thanks for the information; however, downgrading PyTorch is not an option for us.

Any updates from the developers on this?

Thank you