tensorflow / flutter-tflite

TF 2.11.0 MaxPool2D padding incorrect dim calculation bug prevents models from being loaded #258

Open kwikwag opened 1 month ago

kwikwag commented 1 month ago

I am having a problem with a specific TFLite model: it works on other platforms (TFLite on Linux) but fails with the Flutter TFLite package (tflite_flutter: ^0.11.0) in my Android emulator. The model in question is YOLOv6-seg.

Using the recipe for PyTorch-to-TFLite conversion, I export a .tflite model for YOLOv6-seg (see notebook).

When running with Python TFLite (using the ai_edge_litert package), everything works fine.
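
For reference, the passing Python-side check looks roughly like this (the model filename is illustrative):

from ai_edge_litert.interpreter import Interpreter

# Loading and allocating tensors succeeds on Linux with ai-edge-litert.
interpreter = Interpreter(model_path='yolov6s_seg.tflite')  # filename illustrative
interpreter.allocate_tensors()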

However, when using the Flutter package tflite_flutter: ^0.11.0 on Android, the call to await Interpreter.fromAsset(...) fails with the following log message:

E/tflite  (13665): tensorflow/lite/kernels/concatenation.cc:158 t->dims->data[d] != t0->dims->data[d] (16 != 20)
E/tflite  (13665): Node number 63 (CONCATENATION) failed to prepare.

This happens for both the quantized and the non-quantized model. The node number is 63 in the quantized model and 40 in the normal one. The node corresponds to the CSPSPPFModule class, which is a fairly simple combination of concat, add, conv2d, batch norm, maxpool, and ReLU.
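
For context, the failing pattern boils down to an SPPF-style block like the sketch below (illustrative only, not the actual YOLOv6 code): padded, stride-1 max-pools are supposed to preserve the spatial size so their outputs can be concatenated channel-wise.

import torch
from torch import nn

# Illustrative SPPF-style sketch (not the real CSPSPPFModule): padded,
# stride-1 max-pools keep the spatial size, so the input and the pooled
# outputs can be concatenated along the channel dimension.
class SPPFSketch(nn.Module):
  def __init__(self, k: int = 5):
    super().__init__()
    self.pool = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

  def forward(self, x):
    y1 = self.pool(x)
    y2 = self.pool(y1)
    y3 = self.pool(y2)
    # The CONCATENATION node requires all inputs to share spatial dims.
    return torch.cat([x, y1, y2, y3], dim=1)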

For the model, I use the yolov6-seg branch. There is something slightly weird about the YOLOv6 code - the original authors added code to silence warnings:

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    ...

This in turn broke torch.export.export(), so for the export to work I had to remove the warning filter. It just so happens that this warning filter resides in the aforementioned CSPSPPFModule class. I do not have an in-depth understanding of the model, but perhaps the warning the original authors wanted to silence is related to this failure. However, I do not get any actual warning in PyTorch or in TFLite on Python, and inference with the exported model works perfectly using TFLite on Python. Again, the exact same model fails on Flutter Android.

Here is how it looks in the model explorer:

(screenshot: the failing CONCATENATION node in the model explorer)

I assume this is a bug in the underlying TFLite implementation. However, since the model obviously works on other platforms (namely Python on Colab), I assume the bug has been fixed in later TFLite versions, and am therefore filing the issue here with the Flutter TFLite package.

I can provide the quantized model, or you can run the notebook linked above.

kwikwag commented 1 month ago

I can confirm that this also occurs with TensorFlow 2.11 (the Android dependency version for flutter-tflite) and with 2.13.1 on Linux, but does not occur on 2.14.0rc0 or 2.17 (the version currently installed on Colab).

kwikwag commented 1 month ago

I have narrowed it down to a problem with how MaxPool2D gets prepared:

import torch
from torch import nn

import ai_edge_torch

# Minimal repro: a padded, stride-1 max-pool whose output is concatenated
# with its input, mirroring the failing pattern in CSPSPPFModule.
class Test(nn.Module):
  def __init__(self):
    super().__init__()
    self.m = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
  def forward(self, x):
    x1 = self.m(x)  # same spatial size as the input
    return torch.cat([x, x1], dim=1)

m = Test()

torch._dynamo.config.verbose = True
edge_model = ai_edge_torch.convert(m.eval(), sample_args=(torch.zeros((1, 3, 10, 10)), ))
edge_model.export('test.tflite')

Testing with TensorFlow 2.11.0, this fails:

import tensorflow.lite
interpreter = tensorflow.lite.Interpreter('test.tflite')
interpreter.allocate_tensors()

With:

RuntimeError: tensorflow/lite/kernels/concatenation.cc:158 t->dims->data[d] != t0->dims->data[d] (6 != 10)Node number 5 (CONCATENATION) failed to prepare.Failed to apply the default TensorFlow Lite delegate indexed at 0.

This means any module that uses MaxPool2d with padding is destined to fail on flutter-tflite. I'll see if I can come up with a workaround.
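
The numbers in the error are consistent with the padding simply being dropped at prepare time; plugging into the standard pooling output-size formula:

# out = floor((in + 2*pad - kernel) / stride) + 1
def pool_out(in_size, kernel, stride, pad):
  return (in_size + 2 * pad - kernel) // stride + 1

print(pool_out(10, kernel=5, stride=1, pad=2))  # 10 -- what PyTorch (and TF >= 2.14) produce
print(pool_out(10, kernel=5, stride=1, pad=0))  # 6  -- what TF 2.11 apparently computes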

kwikwag commented 1 month ago

The following workaround seems to do the trick; the produced TFLite model loads just fine in TensorFlow 2.11.0. The idea is to pad by 3 with -inf and then crop one pixel from each side, which yields the same effective padding of 2 but presumably keeps the converter from folding the explicit pad back into the max-pool (note that ReconstructMaxPool2dFails below, which pads by exactly 2, still fails):

import torch
from torch import nn

import ai_edge_torch

class MaxPool2dWorkaround(nn.Module):
  def __init__(self):
    super().__init__()
    self.m1 = nn.ConstantPad2d(padding=((3,)*4), value=-float('inf'))
    self.m2 = nn.MaxPool2d(kernel_size=5, stride=1, padding=0)

  def forward(self, x):
    # Pad by 3 with -inf, then crop 1 pixel per side: the effective
    # padding is still 2, but the MaxPool op itself carries padding=0.
    x1 = self.m1(x)
    x1 = x1[:, :, 1:-1, 1:-1]
    x1 = self.m2(x1)
    return torch.cat([x, x1], dim=1)

m = MaxPool2dWorkaround()
x = torch.zeros((1, 3, 10, 10))
print(m(x).shape)

torch._dynamo.config.verbose = True
edge_model = ai_edge_torch.convert(m.eval(), sample_args=(x, ))
edge_model.export('test.tflite')

Just for the record, other attempts were not fruitful:


class PadAndCropWorks(nn.Module):
  def __init__(self):
    super().__init__()
    self.m1 = nn.ConstantPad2d(padding=((2,)*4), value=-float('inf'))

  def forward(self, x):
    x1 = self.m1(x)
    return torch.cat([x, x1[:, :, 2:-2, 2:-2]], dim=1)

class MaxPool2DWithoutPaddingWorks(nn.Module):
  def __init__(self):
    super().__init__()
    self.m1 = nn.MaxPool2d(kernel_size=1, stride=1, padding=0)

  def forward(self, x):
    x1 = self.m1(x)
    return torch.cat([x, x1], dim=1)

class MaxPool2DWithoutCatWorks(nn.Module):
  def __init__(self):
    super().__init__()
    self.m1 = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)

  def forward(self, x):
    x1 = self.m1(x)
    return x1

class JustPadWorksButBadOutput(nn.Module):
  def __init__(self):
    super().__init__()
    self.m1 = nn.ConstantPad2d(padding=((2,)*4), value=-float('inf'))

  def forward(self, x):
    x1 = self.m1(x)
    return x1

class JustPadWith3WorksButBadOutput(nn.Module):
  def __init__(self):
    super().__init__()
    self.m1 = nn.ConstantPad2d(padding=((3,)*4), value=-float('inf'))

  def forward(self, x):
    x1 = self.m1(x)
    return x1

class ReconstructMaxPool2dFails(nn.Module):
  def __init__(self):
    super().__init__()
    self.m1 = nn.ConstantPad2d(padding=((2,)*4), value=-float('inf'))
    self.m2 = nn.MaxPool2d(kernel_size=5, stride=1, padding=0)

  def forward(self, x):
    x1 = self.m1(x)
    x1 = self.m2(x1)
    return torch.cat([x, x1], dim=1)

class ReconstructMaxPool2dWithMulFails(nn.Module):
  def __init__(self):
    super().__init__()
    self.m1 = nn.ConstantPad2d(padding=((2,)*4), value=-float('inf'))
    self.m2 = nn.MaxPool2d(kernel_size=5, stride=1, padding=0)

  def forward(self, x):
    x1 = self.m1(x) * (1 - 1e-5)
    x1 = self.m2(x1)
    return torch.cat([x, x1], dim=1)

class ExtraPadAndCropWorksGoodOutput(nn.Module):
  def __init__(self):
    super().__init__()
    self.m1 = nn.ConstantPad2d(padding=((3,)*4), value=-float('inf'))

  def forward(self, x):
    x1 = self.m1(x)
    x1 = x1[:, :, 1:-1, 1:-1]
    return x1
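
Just to be sure the workaround is numerically equivalent and not merely shape-compatible, here is a quick sanity check (reusing the MaxPool2dWorkaround class from above; PyTorch likewise pads with -inf for max-pooling, so the results should match exactly):

ref = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
wa = MaxPool2dWorkaround()
x = torch.randn(1, 3, 10, 10)
# MaxPool2dWorkaround returns cat([x, pooled]); compare only the pooled half.
assert torch.allclose(wa(x)[:, 3:], ref(x))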

I'll try it on YOLOv6 and report...

kwikwag commented 1 month ago

Using the aforementioned workaround, the model works with TF 2.11 and thus with flutter-tflite. I would therefore suggest creating an issue for upgrading the TF library version.

On another note, the quantized model created with TF 2.17 is not backwards compatible with TFLite 2.11, as it requires the hybrid transpose-conv kernel (introduced in v2.17), but that is another issue.
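
For anyone who wants to check which ops a given .tflite file requires before targeting an older runtime, the TFLite model analyzer may help (assuming a TF version that ships tf.lite.experimental.Analyzer, roughly 2.9+):

import tensorflow as tf

# Prints every op used in the flatbuffer (e.g. TRANSPOSE_CONV), so
# incompatibilities with an older runtime can be spotted before shipping.
tf.lite.experimental.Analyzer.analyze(model_path='test.tflite')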

kwikwag commented 1 month ago

Here is a complete example demonstrating the failure and the workaround on Colab. (Cell 1 sets up a virtualenv so that TFLite 2.11.0 can be exercised alongside the newer TensorFlow that ai-edge-torch pulls in.)

Cell 1

%%bash
# the torch-xla version pinning is due to https://github.com/google-ai-edge/ai-edge-torch/issues/307
pip install -q ai-edge-torch torch-xla==2.4.0 virtualenv

[ -d venv ] || python -m virtualenv venv

source venv/bin/activate
pip install -q tensorflow==2.11.0 numpy==1.24.2

cat > test.sh <<TEST_SH
source venv/bin/activate

# ignore redundant warning outputs, see: https://stackoverflow.com/questions/35911252
export TF_CPP_MIN_LOG_LEVEL=3

python <<PYTHON
import tensorflow.lite

model = tensorflow.lite.Interpreter('test.tflite')
model.allocate_tensors()
PYTHON

TEST_SH

Cell 2

import traceback

import ai_edge_torch
import torch
from torch import nn

class MaxPool2d_TFLite_2_11_Compatible(nn.MaxPool2d):
    """Drop-in replacement for nn.MaxPool2d that sidesteps the TFLite 2.11
    MaxPool2D padding bug: it pads explicitly with -inf and crops, so the
    exported MaxPool op carries no padding of its own.
    """
    @staticmethod
    def from_other(other: nn.MaxPool2d):
        return MaxPool2d_TFLite_2_11_Compatible(
            kernel_size=other.kernel_size,
            stride=other.stride,
            padding=other.padding,
            dilation=other.dilation,
            return_indices=other.return_indices,
            ceil_mode=other.ceil_mode,
        )

    def forward(self, input: torch.Tensor):
        pad = self.padding
        if isinstance(pad, int):
            pad = (pad + 1,) * 4
        else:
            # nn.MaxPool2d padding is (padH, padW), but F.pad takes
            # (left, right, top, bottom), i.e. the W padding comes first.
            pad = (pad[1] + 1, pad[1] + 1, pad[0] + 1, pad[0] + 1)
        # Over-pad by 1 with -inf, then crop 1 pixel per side: same
        # effective padding, but max_pool2d itself runs with padding=0.
        x = nn.functional.pad(input, pad=pad, value=-float('inf'))
        x = x[..., 1:-1, 1:-1]
        x = nn.functional.max_pool2d(
          x, self.kernel_size, self.stride,
          padding=0, dilation=self.dilation, ceil_mode=self.ceil_mode,
          return_indices=self.return_indices)
        return x
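
# A hypothetical helper showing how from_other could retrofit an existing
# network: recursively swap every plain nn.MaxPool2d for the compatible
# version (the helper name is illustrative, not part of any library).
def patch_maxpools(module: nn.Module) -> None:
    for name, child in module.named_children():
        if type(child) is nn.MaxPool2d:
            setattr(module, name, MaxPool2d_TFLite_2_11_Compatible.from_other(child))
        else:
            patch_maxpools(child)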

class UsingNNMaxPool2d(nn.Module):
  def __init__(self):
    super().__init__()
    self.m = nn.MaxPool2d(5, 1, 2)

  def forward(self, x):
    return torch.cat([x, self.m(x)], dim=1)

class UsingMaxPool2dWorkaround(nn.Module):
  def __init__(self):
    super().__init__()

    self.m = MaxPool2d_TFLite_2_11_Compatible.from_other(nn.MaxPool2d(5, 1, 2))

  def forward(self, x):
    return torch.cat([x, self.m(x)], dim=1)

class ControlShouldAlwaysWork(nn.Module):
  def __init__(self):
    super().__init__()
  def forward(self, x):
    return torch.cat([x, x], dim=1)

x = torch.zeros((1, 3, 10, 10))

torch._dynamo.config.verbose = True
for M in [ControlShouldAlwaysWork, UsingNNMaxPool2d, UsingMaxPool2dWorkaround]:
  m = M().eval()
  try:
    print(f"<<< Testing {m.__class__.__name__}: ")
    edge_model = ai_edge_torch.convert(m, sample_args=(x, ))
    !rm -rf test.tflite
    edge_model.export('test.tflite')
    !bash test.sh
    print(">>> Success" if _exit_code == 0 else ">>> Failed with exit code")
  except Exception:
    traceback.print_exc()
    print(">>> Failed with exception")

Output:

<<< Testing ControlShouldAlwaysWork: 
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
>>> Success
<<< Testing UsingNNMaxPool2d: 
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
Traceback (most recent call last):
  File "<stdin>", line 4, in <module>
  File "/content/venv/lib/python3.10/site-packages/tensorflow/lite/python/interpreter.py", line 513, in allocate_tensors
    return self._interpreter.AllocateTensors()
RuntimeError: tensorflow/lite/kernels/concatenation.cc:158 t->dims->data[d] != t0->dims->data[d] (6 != 10)Node number 4 (CONCATENATION) failed to prepare.Failed to apply the default TensorFlow Lite delegate indexed at 0.
>>> Failed with exit code
<<< Testing UsingMaxPool2dWorkaround: 
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
>>> Success