google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0

The result of ConvTranspose2d without bias followed by BatchNorm2d is incorrect. #28

Open AkiSakurai opened 1 month ago

AkiSakurai commented 1 month ago

Description of the bug:

from torch import nn
import torch
import ai_edge_torch
import numpy as np

channels = 2
size = 2
sample_bias = 15

model = nn.Sequential(
    nn.ConvTranspose2d(channels, channels, 1, stride=2, bias=False),
    nn.BatchNorm2d(channels),
)

# One forward pass in training mode on a shifted input, so BatchNorm2d
# accumulates non-trivial running statistics before eval.
model(torch.rand(1, channels, size, size) + sample_bias)
model.eval()

sample = torch.rand(1, channels, size, size)

# Convert the eval-mode model and compare its output against PyTorch.
model_edge = ai_edge_torch.convert(model, (sample,))

print("converted", model_edge(sample))
print("original", model(sample).detach().numpy())
print("maximum absolute difference",
      np.max(np.abs(model_edge(sample) - model(sample).detach().numpy())))
converted [[[[-0.44367838 -0.16322662 -0.53445685]
   [-0.16322662 -0.16322662 -0.16322662]
   [-0.13195059 -0.16322662 -0.44492   ]]

  [[-0.35995242 -0.23823708 -0.32022056]
   [-0.23823708 -0.23823708 -0.23823708]
   [-0.5477077  -0.23823708 -0.35932225]]]]
original [[[[-0.11722518  0.16322662 -0.20800363]
   [ 0.16322662  0.16322662  0.16322662]
   [ 0.19450262  0.16322662 -0.11846679]]

  [[ 0.11652175  0.23823708  0.1562536 ]
   [ 0.23823708  0.23823708  0.23823708]
   [-0.07123353  0.23823708  0.11715191]]]]
maximum absolute difference 0.47647417
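
(As an aside, the converted output differs from the original by a constant per-channel offset: -2 x 0.16322662 on channel 0 and -2 x 0.23823708 on channel 1, the latter matching the reported maximum absolute difference. A constant per-channel offset of this kind is consistent with a spurious bias term being introduced during conversion.)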

Actual vs expected behavior:

No response

Any other information you'd like to share?

No response

talumbau commented 1 week ago

Hi,

Thanks for filing this bug! I can confirm the problem exists when converting ConvTranspose2d, and that it involves the bias tensor. With bias=False as you have above, a garbage bias value (AFAICT) is propagated through the edge model conversion and becomes part of the inference calculation. With bias=True, the two models behave as expected when running your script (agreement to ~1e-8 in my tests). I will update this bug again when a fix has landed.
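
Until then, a possible workaround suggested by the observation above is to construct the layer with bias=True and then zero the bias by hand, so the layer computes the same function as the bias=False version while the converter still sees a real bias tensor. A minimal sketch (the explicit zeroing step is my own assumption, not a confirmed fix):

import torch
from torch import nn
import ai_edge_torch

channels = 2

# Declare the layer with bias=True so the converter handles a real bias
# tensor, then zero it so the layer is mathematically identical to the
# original bias=False configuration.
conv = nn.ConvTranspose2d(channels, channels, 1, stride=2, bias=True)
with torch.no_grad():
    conv.bias.zero_()

model = nn.Sequential(conv, nn.BatchNorm2d(channels))
model.eval()

sample = torch.rand(1, channels, 2, 2)
model_edge = ai_edge_torch.convert(model, (sample,))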