tensorflow / flutter-tflite


TF 2.11.0 MaxPool2D padding incorrect dim calculation bug prevents models from being loaded #258

Open kwikwag opened 3 days ago

kwikwag commented 3 days ago

I am having a problem with a specific TFLite model: it works on other platforms (TFLite on Linux) but fails with the Flutter TFLite package (tflite_flutter: ^0.11.0) in my Android emulator. The model in question is YOLOv6-seg.

Following the recipe for converting PyTorch models to TFLite, I exported a .tflite model for YOLOv6-seg (see notebook).

When running it with TFLite in Python (using the ai_edge_litert package), everything seems to work fine.

However, when using the Flutter package tflite_flutter: ^0.11.0 on Android, the call to await Interpreter.fromAsset(...) fails with the following log message:

E/tflite  (13665): tensorflow/lite/kernels/concatenation.cc:158 t->dims->data[d] != t0->dims->data[d] (16 != 20)
E/tflite  (13665): Node number 63 (CONCATENATION) failed to prepare.

This happens for both the quantized and the non-quantized model. The failing node is number 63 in the quantized model and 40 in the non-quantized one. In both cases it corresponds to the CSPSPPFModule class, which is a fairly simple combination of concat, add, conv2d, BN, maxpool and ReLU.

For the model, I use the yolov6-seg branch. There is something slightly odd about the YOLOv6 code. The original authors added code to silence warnings:

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    ...

This, in turn, broke torch.export.export(), so for the export to work I had to remove the warning filter. As it happens, that code is in the same CSPSPPFModule class mentioned above. I do not have an in-depth understanding of the model, but perhaps the warning the original authors wanted to silence is related to this failure. However, I do not get any actual warning in PyTorch or in TFLite on Python, and inference with the exported model works perfectly with TFLite on Python. Again, the exact same model fails with Flutter on Android.

Here is how it looks in Model Explorer:

[Screenshot: Model Explorer view of the model]

I assume this is a bug in the underlying TFLite implementation. However, since the model clearly works on other platforms (namely in Python on Colab), I assume it has been fixed in later TFLite versions, so I am filing it here against the Flutter TFLite package.

I can provide the quantized model, or you can run the notebook linked to above.

kwikwag commented 3 days ago

I can confirm that this also occurs with TensorFlow 2.11 (the version flutter-tflite depends on for Android) and 2.13.1 on Linux, but not with 2.14.0rc0 or 2.17 (the version currently installed on Colab).
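For scripts or CI checks that need to know whether a given runtime falls in the affected range, a minimal sketch follows. It assumes, based only on the four versions tested in this thread, that anything older than 2.14 is affected; the actual fix may have landed anywhere between 2.13.1 and 2.14.0rc0. The function name likely_affected is hypothetical.

```python
import re


def likely_affected(version: str) -> bool:
    """Return True if a TF/TFLite version string is older than 2.14.

    ASSUMPTION: the 2.14 cutoff comes only from the versions tested in this
    thread (2.11 and 2.13.1 fail; 2.14.0rc0 and 2.17 work).
    """
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) < (2, 14)


for v in ["2.11.0", "2.13.1", "2.14.0rc0", "2.17.0"]:
    print(v, likely_affected(v))
```

In Python the argument would typically be tf.__version__; for the Android build, the relevant number is the TensorFlow Lite dependency version that flutter-tflite pulls in.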

kwikwag commented 3 days ago

I have narrowed it down to a problem with how MaxPool2D gets prepared:

import torch
from torch import nn

import ai_edge_torch

class Test(nn.Module):
  def __init__(self):
    super().__init__()
    self.m = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
  def forward(self, x):
    x1 = self.m(x)
    return torch.cat([x, x1], dim=1)

m = Test()

torch._dynamo.config.verbose = True
edge_model = ai_edge_torch.convert(m.eval(), sample_args=(torch.zeros((1, 3, 10, 10)), ))
edge_model.export('test.tflite')

Loading the result with TensorFlow 2.11.0 fails:

import tensorflow.lite
interpreter = tensorflow.lite.Interpreter('test.tflite')
interpreter.allocate_tensors()

With:

RuntimeError: tensorflow/lite/kernels/concatenation.cc:158 t->dims->data[d] != t0->dims->data[d] (6 != 10)Node number 5 (CONCATENATION) failed to prepare.Failed to apply the default TensorFlow Lite delegate indexed at 0.

Does this mean that any model using MaxPool2D with padding is bound to fail on flutter-tflite? I'll try to find a workaround.
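Both CONCATENATION errors fit the same pattern: the pooled tensor has the spatial size it would have if the padding were ignored, while CONCATENATION requires every non-axis dimension to match. A quick check with the usual pooling output-size formula, out = floor((in + 2*padding - kernel) / stride) + 1, is sketched below. The YOLOv6-seg numbers assume the pooling in CSPSPPFModule also uses kernel_size=5, padding=2 on a 20x20 feature map (an assumption, not verified here), and the arithmetic only shows the numbers are consistent with dropped padding, not where in TF 2.11.0 that happens.

```python
def pool_out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one spatial axis for max pooling (floor mode, dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


# Minimal repro: 10x10 input, MaxPool2d(kernel_size=5, stride=1, padding=2).
repro_expected = pool_out_size(10, kernel=5, stride=1, padding=2)  # equals the input size
repro_if_padding_dropped = pool_out_size(10, kernel=5, stride=1, padding=0)

# YOLOv6-seg (ASSUMPTION: kernel_size=5, padding=2 on a 20x20 feature map).
yolo_expected = pool_out_size(20, kernel=5, stride=1, padding=2)
yolo_if_padding_dropped = pool_out_size(20, kernel=5, stride=1, padding=0)

print(repro_expected, repro_if_padding_dropped)  # 10 6  -> matches "(6 != 10)"
print(yolo_expected, yolo_if_padding_dropped)    # 20 16 -> matches "(16 != 20)"
```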

kwikwag commented 3 days ago

The following workaround seems to work; the resulting TFLite model loads fine in TensorFlow 2.11.0:

import torch
from torch import nn

import ai_edge_torch

class MaxPool2dWorkaround(nn.Module):
  def __init__(self):
    super().__init__()
    self.m1 = nn.ConstantPad2d(padding=((3,)*4), value=-float('inf'))
    self.m2 = nn.MaxPool2d(kernel_size=5, stride=1, padding=0)

  def forward(self, x):
    x1 = self.m1(x)
    x1 = x1[:, :, 1:-1, 1:-1]
    x1 = self.m2(x1)
    return torch.cat([x, x1], dim=1)

m = MaxPool2dWorkaround()
x = torch.zeros((1, 3, 10, 10))
print(m(x).shape)

torch._dynamo.config.verbose = True
edge_model = ai_edge_torch.convert(m.eval(), sample_args=(x, ))
edge_model.export('test.tflite')
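The workaround relies on two facts: nn.MaxPool2d pads implicitly with negative infinity, which ConstantPad2d(value=-inf) reproduces, and the extra pad of 1 is removed again by the [1:-1] crop, so the pooling sees exactly the same padded map. A dependency-free sketch that checks this equivalence follows; plain Python lists stand in for a single-channel feature map, and the helper names are hypothetical. It only checks numerical equivalence; why the pad-3-and-crop variant converts cleanly while a plain pad of 2 (ReconstructMaxPool2dFails) does not is not established here.

```python
import math
import random

NEG_INF = -math.inf


def pad(grid, p):
    """Pad a 2D list with p rows/columns of -inf on every side (like ConstantPad2d(p, -inf))."""
    width = len(grid[0]) + 2 * p
    top = [[NEG_INF] * width for _ in range(p)]
    middle = [[NEG_INF] * p + list(row) + [NEG_INF] * p for row in grid]
    bottom = [[NEG_INF] * width for _ in range(p)]
    return top + middle + bottom


def crop(grid, c):
    """Drop c rows/columns from every side (like x[:, :, c:-c, c:-c] for c >= 1)."""
    return [row[c:len(row) - c] for row in grid[c:len(grid) - c]]


def max_pool_valid(grid, k):
    """Stride-1, k x k max pooling with no padding."""
    h, w = len(grid), len(grid[0])
    return [
        [max(grid[i + di][j + dj] for di in range(k) for dj in range(k))
         for j in range(w - k + 1)]
        for i in range(h - k + 1)
    ]


def max_pool_reference(grid, k, p):
    """What MaxPool2d(k, stride=1, padding=p) computes: implicit -inf padding, then pool."""
    return max_pool_valid(pad(grid, p), k)


def max_pool_workaround(grid, k, p):
    """Workaround pattern: pad by p + 1, crop 1 from each side, then pool with padding=0."""
    return max_pool_valid(crop(pad(grid, p + 1), 1), k)


random.seed(0)
x = [[random.uniform(-1.0, 1.0) for _ in range(10)] for _ in range(10)]

ref = max_pool_reference(x, k=5, p=2)
alt = max_pool_workaround(x, k=5, p=2)
print(len(ref), len(ref[0]), ref == alt)  # 10 10 True
```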

For the record, other attempts were not fruitful (each class name records the outcome of that attempt):


class PadAndCropWorks(nn.Module):
  def __init__(self):
    super().__init__()
    self.m1 = nn.ConstantPad2d(padding=((2,)*4), value=-float('inf'))

  def forward(self, x):
    x1 = self.m1(x)
    return torch.cat([x, x1[:, :, 2:-2, 2:-2]], dim=1)

class MaxPool2DWithoutPaddingWorks(nn.Module):
  def __init__(self):
    super().__init__()
    self.m1 = nn.MaxPool2d(kernel_size=1, stride=1, padding=0)

  def forward(self, x):
    x1 = self.m1(x)
    return torch.cat([x, x1], dim=1)

class MaxPool2DWithoutCatWorks(nn.Module):
  def __init__(self):
    super().__init__()
    self.m1 = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)

  def forward(self, x):
    x1 = self.m1(x)
    return x1

class JustPadWorksButBadOutput(nn.Module):
  def __init__(self):
    super().__init__()
    self.m1 = nn.ConstantPad2d(padding=((2,)*4), value=-float('inf'))

  def forward(self, x):
    x1 = self.m1(x)
    return x1

class JustPadWith3WorksButBadOutput(nn.Module):
  def __init__(self):
    super().__init__()
    self.m1 = nn.ConstantPad2d(padding=((3,)*4), value=-float('inf'))

  def forward(self, x):
    x1 = self.m1(x)
    return x1

class ReconstructMaxPool2dFails(nn.Module):
  def __init__(self):
    super().__init__()
    self.m1 = nn.ConstantPad2d(padding=((2,)*4), value=-float('inf'))
    self.m2 = nn.MaxPool2d(kernel_size=5, stride=1, padding=0)

  def forward(self, x):
    x1 = self.m1(x)
    x1 = self.m2(x1)
    return torch.cat([x, x1], dim=1)

class ReconstructMaxPool2dWithMulFails(nn.Module):
  def __init__(self):
    super().__init__()
    self.m1 = nn.ConstantPad2d(padding=((2,)*4), value=-float('inf'))
    self.m2 = nn.MaxPool2d(kernel_size=5, stride=1, padding=0)

  def forward(self, x):
    x1 = self.m1(x) * (1 - 1e-5)
    x1 = self.m2(x1)
    return torch.cat([x, x1], dim=1)

class ExtraPadAndCropWorksGoodOutput(nn.Module):
  def __init__(self):
    super().__init__()
    self.m1 = nn.ConstantPad2d(padding=((3,)*4), value=-float('inf'))

  def forward(self, x):
    x1 = self.m1(x)
    x1 = x1[:, :, 1:-1, 1:-1]
    return x1

I'll try it on YOLOv6 and report...

kwikwag commented 5 hours ago

With the workaround above, the model works with TF 2.11 and therefore with flutter-tflite. I would still suggest opening an issue to upgrade the TensorFlow Lite version that flutter-tflite depends on.

On another note, the quantized model created with TF 2.17 is not backwards compatible with TF Lite 2.11, as it requires the hybrid transpose-conv (introduced in v2.17), but that is a separate issue.