[Open] kwikwag opened this issue 1 month ago
I can confirm that this also occurs on TensorFlow 2.11 (the Android dependency version for flutter-tflite) and on 2.13.1 on Linux, but does not occur on 2.14.0rc0 or 2.17 (the version currently installed on Colab).
I have narrowed it down to a problem with how MaxPool2D gets prepared:
```python
import torch
from torch import nn
import ai_edge_torch


class Test(nn.Module):
    def __init__(self):
        super().__init__()
        self.m = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)

    def forward(self, x):
        x1 = self.m(x)
        return torch.cat([x, x1], dim=1)


m = Test()
torch._dynamo.config.verbose = True
edge_model = ai_edge_torch.convert(m.eval(), sample_args=(torch.zeros((1, 3, 10, 10)),))
edge_model.export('test.tflite')
```
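As a sanity check on the PyTorch side (nothing ai_edge_torch specific), the pooled output keeps the 10×10 spatial size, so the channel-wise concat is well-formed:

```python
import torch
from torch import nn

# With kernel_size=5, stride=1, padding=2 the output side is
# (10 + 2*2 - 5) // 1 + 1 == 10, so H and W are preserved.
m = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
x = torch.zeros((1, 3, 10, 10))
print(m(x).shape)                          # torch.Size([1, 3, 10, 10])
print(torch.cat([x, m(x)], dim=1).shape)   # torch.Size([1, 6, 10, 10])
```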
Testing with TensorFlow 2.11.0, this fails:
```python
import tensorflow.lite

interpreter = tensorflow.lite.Interpreter('test.tflite')
interpreter.allocate_tensors()
```
With:

```
RuntimeError: tensorflow/lite/kernels/concatenation.cc:158 t->dims->data[d] != t0->dims->data[d] (6 != 10)Node number 5 (CONCATENATION) failed to prepare.Failed to apply the default TensorFlow Lite delegate indexed at 0.
```
The 6 in the error is telling: on a 10×10 input, a kernel-5, stride-1 pool with no padding yields 10 - 5 + 1 = 6 per side, while with padding=2 the output should stay 10×10. So TFLite 2.11's prepare step appears to compute the pool's output shape as if the padding weren't there. This means any module that uses `MaxPool2d` with padding seems destined to fail to run on flutter-tflite. I'll try to see if I can come up with a workaround.
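In the meantime, affected layers are easy to spot ahead of conversion. A minimal sketch (`find_padded_maxpools` is my own helper, not an ai_edge_torch API), reusing the `Test` module from above:

```python
from torch import nn

def find_padded_maxpools(model: nn.Module):
    """Yield the names of MaxPool2d layers whose implicit padding is non-zero."""
    for name, module in model.named_modules():
        if isinstance(module, nn.MaxPool2d) and module.padding not in (0, (0, 0)):
            yield name

print(list(find_padded_maxpools(Test())))  # ['m']
```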
The following workaround seems to work; the produced TFLite model loads just fine in TensorFlow 2.11.0:
```python
import torch
from torch import nn
import ai_edge_torch


class MaxPool2dWorkaround(nn.Module):
    def __init__(self):
        super().__init__()
        # Pad by 3 with -inf, then crop one pixel from each border,
        # leaving an effective padding of 2.
        self.m1 = nn.ConstantPad2d(padding=(3,) * 4, value=-float('inf'))
        self.m2 = nn.MaxPool2d(kernel_size=5, stride=1, padding=0)

    def forward(self, x):
        x1 = self.m1(x)
        x1 = x1[:, :, 1:-1, 1:-1]
        x1 = self.m2(x1)
        return torch.cat([x, x1], dim=1)


m = MaxPool2dWorkaround()
x = torch.zeros((1, 3, 10, 10))
print(m(x).shape)
torch._dynamo.config.verbose = True
edge_model = ai_edge_torch.convert(m.eval(), sample_args=(x,))
edge_model.export('test.tflite')
```
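A quick numerical check (plain PyTorch, reusing `Test` and `MaxPool2dWorkaround` from the snippets above) that the workaround is equivalent to the padded pool:

```python
import torch

ref = Test()                 # nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
wk = MaxPool2dWorkaround()   # pad-by-3 / crop-1 / unpadded pool
x = torch.randn(1, 3, 10, 10)
# Both reduce the same windows over the same -inf-padded values,
# so the outputs should match exactly.
assert torch.equal(ref(x), wk(x))
print("workaround matches nn.MaxPool2d")
```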
Just for the record, other attempts were not fruitful; the class names record the outcome of each attempt:
```python
class PadAndCropWorks(nn.Module):
    def __init__(self):
        super().__init__()
        self.m1 = nn.ConstantPad2d(padding=(2,) * 4, value=-float('inf'))

    def forward(self, x):
        x1 = self.m1(x)
        return torch.cat([x, x1[:, :, 2:-2, 2:-2]], dim=1)


class MaxPool2DWithoutPaddingWorks(nn.Module):
    def __init__(self):
        super().__init__()
        self.m1 = nn.MaxPool2d(kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x1 = self.m1(x)
        return torch.cat([x, x1], dim=1)


class MaxPool2DWithoutCatWorks(nn.Module):
    def __init__(self):
        super().__init__()
        self.m1 = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)

    def forward(self, x):
        x1 = self.m1(x)
        return x1


class JustPadWorksButBadOutput(nn.Module):
    def __init__(self):
        super().__init__()
        self.m1 = nn.ConstantPad2d(padding=(2,) * 4, value=-float('inf'))

    def forward(self, x):
        x1 = self.m1(x)
        return x1


class JustPadWith3WorksButBadOutput(nn.Module):
    def __init__(self):
        super().__init__()
        self.m1 = nn.ConstantPad2d(padding=(3,) * 4, value=-float('inf'))

    def forward(self, x):
        x1 = self.m1(x)
        return x1


class ReconstructMaxPool2dFails(nn.Module):
    def __init__(self):
        super().__init__()
        self.m1 = nn.ConstantPad2d(padding=(2,) * 4, value=-float('inf'))
        self.m2 = nn.MaxPool2d(kernel_size=5, stride=1, padding=0)

    def forward(self, x):
        x1 = self.m1(x)
        x1 = self.m2(x1)
        return torch.cat([x, x1], dim=1)


class ReconstructMaxPool2dWithMulFails(nn.Module):
    def __init__(self):
        super().__init__()
        self.m1 = nn.ConstantPad2d(padding=(2,) * 4, value=-float('inf'))
        self.m2 = nn.MaxPool2d(kernel_size=5, stride=1, padding=0)

    def forward(self, x):
        x1 = self.m1(x) * (1 - 1e-5)
        x1 = self.m2(x1)
        return torch.cat([x, x1], dim=1)


class ExtraPadAndCropWorksGoodOutput(nn.Module):
    def __init__(self):
        super().__init__()
        self.m1 = nn.ConstantPad2d(padding=(3,) * 4, value=-float('inf'))

    def forward(self, x):
        x1 = self.m1(x)
        x1 = x1[:, :, 1:-1, 1:-1]
        return x1
```
I'll try it on YOLOv6 and report...
Using the aforementioned workaround, the model works with TF 2.11 and thus with flutter-tflite. So I would suggest creating an issue to upgrade the bundled TF library version.
On another note, the quantized model created with TF 2.17 is not backwards compatible with TFLite 2.11, as it requires the hybrid transpose-conv kernel (introduced in v2.17); but that is another issue.
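To apply the workaround to an existing model such as YOLOv6 without editing its source, one option is to swap the offending modules in place before conversion. A minimal sketch (the helper and class names are mine, not an ai_edge_torch API; it assumes square padding and default dilation/return_indices/ceil_mode):

```python
import torch
from torch import nn


class CompatibleMaxPool2d(nn.Module):
    """Pad by padding+1 with -inf, crop one pixel, then pool with padding=0."""

    def __init__(self, kernel_size, stride, padding):
        super().__init__()
        p = padding if isinstance(padding, int) else padding[0]  # square padding assumed
        self.pad = nn.ConstantPad2d((p + 1,) * 4, -float('inf'))
        self.pool = nn.MaxPool2d(kernel_size, stride, padding=0)

    def forward(self, x):
        return self.pool(self.pad(x)[..., 1:-1, 1:-1])


def swap_maxpools(model: nn.Module) -> nn.Module:
    """Recursively replace padded nn.MaxPool2d modules with the workaround."""
    for name, child in model.named_children():
        if isinstance(child, nn.MaxPool2d) and child.padding not in (0, (0, 0)):
            setattr(model, name, CompatibleMaxPool2d(
                child.kernel_size, child.stride, child.padding))
        else:
            swap_maxpools(child)
    return model
```

Calling `swap_maxpools(model.eval())` before `ai_edge_torch.convert(...)` should then produce a TFLite file without padded MAX_POOL_2D nodes.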
Here is a complete example that reproduces the failure and the workaround on Colab.
```bash
%%bash
# the torch-xla version pinning is due to https://github.com/google-ai-edge/ai-edge-torch/issues/307
pip install -q ai-edge-torch torch-xla==2.4.0 virtualenv
[ -d venv ] || python -m virtualenv venv
source venv/bin/activate
pip install -q tensorflow==2.11.0 numpy==1.24.2

cat > test.sh <<TEST_SH
source venv/bin/activate
# ignore redundant warning outputs, see: https://stackoverflow.com/questions/35911252
export TF_CPP_MIN_LOG_LEVEL=3
python <<PYTHON
import tensorflow.lite
model = tensorflow.lite.Interpreter('test.tflite')
model.allocate_tensors()
PYTHON
TEST_SH
```
```python
import traceback

import ai_edge_torch
import torch
from torch import nn


class MaxPool2d_TFLite_2_11_Compatible(nn.MaxPool2d):
    """Drop-in replacement for nn.MaxPool2d whose converted TFLite graph
    also prepares correctly under TFLite 2.11 (pad-then-crop workaround)."""

    @staticmethod
    def from_other(other: nn.MaxPool2d):
        return MaxPool2d_TFLite_2_11_Compatible(
            kernel_size=other.kernel_size,
            stride=other.stride,
            padding=other.padding,
            dilation=other.dilation,
            return_indices=other.return_indices,
            ceil_mode=other.ceil_mode,
        )

    def forward(self, input: torch.Tensor):
        pad = self.padding
        if isinstance(pad, int):
            pad = (pad + 1,) * 4
        else:
            # F.pad takes (left, right, top, bottom), i.e. the W amount first,
            # while nn.MaxPool2d's padding tuple is (padH, padW).
            pad = (pad[1] + 1, pad[1] + 1, pad[0] + 1, pad[0] + 1)
        x = nn.functional.pad(input, pad=pad, value=-float('inf'))
        x = x[..., 1:-1, 1:-1]
        x = nn.functional.max_pool2d(
            x, self.kernel_size, self.stride,
            padding=0, dilation=self.dilation, ceil_mode=self.ceil_mode,
            return_indices=self.return_indices)
        return x

class UsingNNMaxPool2d(nn.Module):
    def __init__(self):
        super().__init__()
        self.m = nn.MaxPool2d(5, 1, 2)

    def forward(self, x):
        return torch.cat([x, self.m(x)], dim=1)


class UsingMaxPool2dWorkaround(nn.Module):
    def __init__(self):
        super().__init__()
        self.m = MaxPool2d_TFLite_2_11_Compatible.from_other(nn.MaxPool2d(5, 1, 2))

    def forward(self, x):
        return torch.cat([x, self.m(x)], dim=1)


class ControlShouldAlwaysWork(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.cat([x, x], dim=1)


x = torch.zeros((1, 3, 10, 10))
torch._dynamo.config.verbose = True
for M in [ControlShouldAlwaysWork, UsingNNMaxPool2d, UsingMaxPool2dWorkaround]:
    m = M().eval()
    try:
        print(f"<<< Testing {m.__class__.__name__}: ")
        edge_model = ai_edge_torch.convert(m, sample_args=(x,))
        !rm -rf test.tflite
        edge_model.export('test.tflite')
        !bash test.sh
        # _exit_code is set by Colab after a ! shell command
        print(">>> Success" if _exit_code == 0 else ">>> Failed with exit code")
    except Exception:
        traceback.print_exc()
        print(">>> Failed with exception")
```
```
<<< Testing ControlShouldAlwaysWork: 
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
>>> Success
<<< Testing UsingNNMaxPool2d: 
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
Traceback (most recent call last):
  File "<stdin>", line 4, in <module>
  File "/content/venv/lib/python3.10/site-packages/tensorflow/lite/python/interpreter.py", line 513, in allocate_tensors
    return self._interpreter.AllocateTensors()
RuntimeError: tensorflow/lite/kernels/concatenation.cc:158 t->dims->data[d] != t0->dims->data[d] (6 != 10)Node number 4 (CONCATENATION) failed to prepare.Failed to apply the default TensorFlow Lite delegate indexed at 0.
>>> Failed with exit code
<<< Testing UsingMaxPool2dWorkaround: 
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
>>> Success
```
I am having a problem with a specific TFLite model: it works on other platforms (TFLite on Linux) but fails with the Flutter TFLite package (`tflite_flutter: ^0.11.0`) in my Android simulator. The model in question is YOLOv6-seg.

Using the recipe for PyTorch to TFLite, I export a `.tflite` model for YOLOv6-seg (see notebook). When running with Python TFLite (using the `ai_edge_litert` package), everything seems to be working fine.
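For reference, the Python-side check that passes is roughly this (a sketch; the model path is illustrative, and the `ai_edge_litert` `Interpreter` mirrors `tf.lite.Interpreter`):

```python
import numpy as np
from ai_edge_litert.interpreter import Interpreter

interpreter = Interpreter(model_path='yolov6seg.tflite')  # illustrative path
interpreter.allocate_tensors()  # the prepare step (succeeds here)
inp = interpreter.get_input_details()[0]
interpreter.set_tensor(inp['index'], np.zeros(inp['shape'], dtype=inp['dtype']))
interpreter.invoke()
print(interpreter.get_output_details()[0]['shape'])
```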
However, when using the Flutter package `tflite_flutter: ^0.11.0` on Android, the call to `await Interpreter.fromAsset(...)` fails with a log message reporting that a node failed to prepare. This happens for both the quantized and the normal model; the node number is 63 in the quantized model and 40 in the normal one. The node corresponds to the `CSPSPPFModule` class, which is a rather simple combination of concat, add, conv2d, BN, maxpool and ReLU.

For the model, I use the `yolov6-seg` branch. There is something slightly weird about the YOLOv6 code: the original authors added code to silence warnings (a warnings filter), and this in turn broke `torch.export.export()`, so for the export to work I had to remove the warning filter. It just so happens that this issue resides in the aforementioned `CSPSPPFModule` class. I do not have an in-depth understanding of the model, but perhaps the warning the original authors wanted to silence has something to do with this failure. However, I am not getting any actual warning in PyTorch or in TFLite on Python, and inference with the exported model works perfectly using TFLite on Python. Again, the exact same model fails on Flutter Android.

Here is how it looks in the model explorer:

[model explorer screenshot]
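For anyone wanting to reproduce this view, the `ai-edge-model-explorer` package can render the file locally. A sketch (the package and `visualize` call are as I understand them from the Model Explorer docs; the path is illustrative):

```python
# pip install ai-edge-model-explorer
import model_explorer

# Opens an interactive graph view of the converted model in the browser.
model_explorer.visualize('yolov6seg.tflite')  # illustrative path
```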
I assume this is a bug in the underlying TFLite implementation. However, since the model obviously works on other platforms (namely Python on Colab), I assume the bug has been fixed in later TFLite versions, and I am therefore filing it here against the Flutter TFLite package.

I can provide the quantized model, or you can run the notebook linked above.