google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0

ai-edge-torch input/output format does not match ImageClassifier model compatibility requirements #91

Closed RubensZimbres closed 3 months ago

RubensZimbres commented 3 months ago

The ImageClassifier model compatibility requirements state that the input of an exported TFLite model must be an image of size [batch x height x width x channels].

However, the example code in the repo uses an image input of size [batch x channels x height x width]:

import ai_edge_torch
import torch
import torchvision

resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1).eval()
sample_input = (torch.rand((1, 3, 224, 224), dtype=torch.float32),)
edge_model = ai_edge_torch.signature("input1", resnet18, sample_input).convert(resnet18, sample_input)
print(resnet18(*sample_input).shape)
print(edge_model(*sample_input).shape)
edge_model.export("/home/user/ai-edge/quasee/resnet18.tflite")

I'm using Python 3.10 in Anaconda, with the following packages:

ai-edge-torch-nightly          0.2.0.dev20240717
torch                          2.4.0.dev20240429+cpu
torch-xla                      2.4.0+git174f407
torchaudio                     2.2.0.dev20240429+cpu
torchvision                    0.19.0.dev20240429+cpu

My first question is: when TensorFlow converts the existing resnet18 model, does it automatically reshape the input to the TFLite format?

I ask because I am adding a custom .tflite model (converted with ai-edge-torch) to a MediaPipe Image Classifier, and it does not work. I exported it via TFLite with and without metadata, quantized and not quantized. None of the variants work, possibly because of the input shape.

import { FilesetResolver, ImageClassifier } from "@mediapipe/tasks-vision";

// Assumed page setup, as in the MediaPipe image classifier demo
let imageClassifier;
let runningMode = "IMAGE";
const demosSection = document.getElementById("demos");

const createImageClassifier = async () => {
  const vision = await FilesetResolver.forVisionTasks(
    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@0.10.14/wasm"
  );
  imageClassifier = await ImageClassifier.createFromOptions(vision, {
    baseOptions: {
      modelAssetPath: `https://storage.googleapis.com/customized/tflite/resnet18.tflite`
      // NOTE: For this demo, we keep the default CPU delegate.
      // A model that works: https://storage.googleapis.com/mediapipe-models/image_classifier/efficientnet_lite0/float32/1/efficientnet_lite0.tflite
    },
    maxResults: 1,
    runningMode: runningMode
  });

  // Show demo section now model is ready to use.
  demosSection.classList.remove("invisible");
};
createImageClassifier();

So ai-edge-torch successfully converts PyTorch models to .tflite, and MediaPipe uses .tflite models for image classification inference. However, the ai-edge-torch example takes an input of shape (1, Channels, Height, Width) (https://github.com/google-ai-edge/ai-edge-torch), while MediaPipe ImageClassifier only works with (1, Height, Width, Channels) (https://www.tensorflow.org/lite/inference_with_metadata/task_library/image_classifier). My idea was to use a custom classifier in MediaPipe, but the two products do not seem to work together.
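To make the mismatch concrete, here is a small pure-Python sketch that guesses the layout of a 4-D input shape and checks it against the ImageClassifier requirement. It assumes, as I read the Task Library page linked above, that batch must be 1 and channels must be 3 (RGB); both function names are hypothetical and belong to neither library.

```python
from typing import Sequence


def describe_image_input(shape: Sequence[int]) -> str:
    """Guess the layout of a 4-D image input shape (hypothetical helper).

    Returns "NHWC" when the last axis holds 3 channels, "NCHW" when axis 1
    does, and "unknown" otherwise. The guess is ambiguous if H or W is 3.
    """
    if len(shape) != 4:
        return "unknown"
    if shape[3] == 3:
        return "NHWC"
    if shape[1] == 3:
        return "NCHW"
    return "unknown"


def meets_image_classifier_input(shape: Sequence[int]) -> bool:
    """Check [batch x height x width x channels] with batch == 1 and channels == 3."""
    return len(shape) == 4 and shape[0] == 1 and shape[3] == 3


# The two shapes discussed in this thread.
for s in [(1, 3, 224, 224), (1, 224, 224, 3)]:
    print(s, describe_image_input(s), meets_image_classifier_input(s))
```

The default PyTorch-style export, (1, 3, 224, 224), is reported as NCHW and fails the check, while (1, 224, 224, 3) passes.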

I tried:

import torch
import torchvision
from torch import nn

resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1).eval()

class PermuteLayer(nn.Module):
    def forward(self, x):
        # (batch, height, width, channels) -> (batch, channels, height, width)
        return x.permute(0, 3, 1, 2)

class CustomResNet(nn.Module):
    def __init__(self, resnet):
        super().__init__()
        self.permute_layer = PermuteLayer()
        self.resnet = resnet

    def forward(self, x):
        x = self.permute_layer(x)
        print(x.shape)  # debug: shape after the permute
        x = self.resnet(x)
        return x

edge_model = CustomResNet(resnet18).eval()
sample_input = (torch.rand((1, 224, 224, 3), dtype=torch.float32),)
edge_model(*sample_input)

This solves the input incompatibility, but the .tflite model still does not work in MediaPipe, even with a signature and metadata. The web page shows this error for the uint8 model:

Error: INVALID_ARGUMENT: Classification tflite models are assumed to have a single subgraph.; Initialize was not ok; StartGraph failed

Screenshot from 2024-07-16 21-58-56

I also tried float32, and it didn't work either. The MediaPipe interface shows the following error:

Error: INVALID_ARGUMENT: Classification tflite models are assumed to have a single subgraph.; Initialize was not ok; StartGraph failed
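The error says the file contains more than one subgraph. One way to check what was actually exported, without installing TensorFlow, is to read the counts directly from the FlatBuffer. This is a minimal sketch, assuming the standard TFLite schema (file identifier "TFL3"; subgraphs and signature_defs are fields 2 and 7 of the root Model table); the function names are hypothetical. One possible but unconfirmed cause of a second subgraph is the first snippet in this issue, where ai_edge_torch.signature("input1", ...) is followed by .convert(resnet18, sample_input), which may register the model under two signatures.

```python
import struct

# Field indices in the root `Model` table of the TFLite schema (assumed:
# version=0, operator_codes=1, subgraphs=2, ..., signature_defs=7).
SUBGRAPHS_FIELD = 2
SIGNATURE_DEFS_FIELD = 7


def _vector_length(buf: bytes, field_index: int) -> int:
    """Return the length of a vector field in the root table, or 0 if absent."""
    table_pos = struct.unpack_from("<I", buf, 0)[0]  # uoffset to the root table
    # The table starts with a signed offset back to its vtable.
    vtable_pos = table_pos - struct.unpack_from("<i", buf, table_pos)[0]
    vtable_size = struct.unpack_from("<H", buf, vtable_pos)[0]
    entry = 4 + 2 * field_index  # skip vtable_size and table_size (two uint16s)
    if entry + 2 > vtable_size:
        return 0  # field not present in this vtable
    field_offset = struct.unpack_from("<H", buf, vtable_pos + entry)[0]
    if field_offset == 0:
        return 0  # field absent
    field_pos = table_pos + field_offset
    vector_pos = field_pos + struct.unpack_from("<I", buf, field_pos)[0]
    return struct.unpack_from("<I", buf, vector_pos)[0]  # vectors start with their length


def count_subgraphs_and_signatures(buf: bytes) -> tuple[int, int]:
    """Count subgraphs and signature defs in a serialized .tflite model."""
    if len(buf) < 8 or buf[4:8] != b"TFL3":
        raise ValueError("not a TFLite FlatBuffer (missing 'TFL3' identifier)")
    return (_vector_length(buf, SUBGRAPHS_FIELD),
            _vector_length(buf, SIGNATURE_DEFS_FIELD))
```

Usage would be something like count_subgraphs_and_signatures(open("resnet18.tflite", "rb").read()); a first value greater than 1 would match the MediaPipe error. Model Explorer or Netron show the same information visually.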

UPDATE:

I was able to make it work without errors with this code:

import ai_edge_torch
import torch
import torchvision
from torch import nn

class PermuteInput(nn.Module):
    def forward(self, x):
        # Permute from (batch, height, width, channels) to (batch, channels, height, width)
        # and cast the uint8 pixels to float32 (no rescaling)
        return x.permute(0, 3, 1, 2).float()

resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)

resnet18_with_reshape = nn.Sequential(
    PermuteInput(),
    resnet18
)

edge_model = resnet18_with_reshape.eval()

sample_input = (torch.randint(0, 256, (1, 224, 224, 3), dtype=torch.uint8),)

# Sanity check in PyTorch before converting
resnet18_with_reshape(*sample_input)

edge_model = ai_edge_torch.convert(edge_model, sample_input)
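To see what the wrapper does to the data without installing PyTorch, here is a pure-Python sketch of x.permute(0, 3, 1, 2).float() applied to a nested-list tensor. It is only an illustration (the function name is hypothetical): element x[n][h][w][c] moves to out[n][c][h][w], and .float() casts without rescaling, so the ResNet receives raw 0-255 values and any normalization the pretrained weights expect has to be handled separately.

```python
def nhwc_to_nchw_float(x: list) -> list:
    """Pure-Python equivalent of x.permute(0, 3, 1, 2).float() on nested lists.

    out[n][c][h][w] == float(x[n][h][w][c]); values are cast, not rescaled.
    """
    n_dim, h_dim = len(x), len(x[0])
    w_dim, c_dim = len(x[0][0]), len(x[0][0][0])
    return [
        [
            [[float(x[n][h][w][c]) for w in range(w_dim)] for h in range(h_dim)]
            for c in range(c_dim)
        ]
        for n in range(n_dim)
    ]


# A 1 x 2 x 2 x 3 "image" in NHWC order (batch, height, width, RGB).
image = [[[[1, 2, 3], [4, 5, 6]],
          [[7, 8, 9], [10, 11, 12]]]]
nchw = nhwc_to_nchw_float(image)
print(nchw[0][0])  # the red channel as a 2 x 2 plane
```

Each channel becomes its own height x width plane, which is the layout torchvision's ResNet expects.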

I uploaded the model to https://netron.app/ and got this:

Screenshot from 2024-07-17 19-16-14

However, the model returns no class labels, because I am unable to add metadata (resnet_labels.txt) to it: when I do, MediaPipe rejects the .tflite model. Another issue is that ai-edge-torch runs on TensorFlow 2.17.0, while the metadata writer only runs on TensorFlow 2.13.0. So two separate environments are necessary, which is counterproductive.

Screenshot from 2024-07-17 14-43-11

pkgoogle commented 3 months ago

Hi @RubensZimbres, it seems like there are a couple of things going on here...

My first question is: when TensorFlow converts the existing resnet18 model, does it automatically reshape the input to the TFLite format?

Short answer: yes, here's my script:

import ai_edge_torch
import torch
from torch import nn
import torchvision

class PermuteInput(nn.Module):
    def forward(self, x):
        # Permute from (batch, height, width, channels) to (batch, channels, height, width)
        return x.permute(0, 3, 1, 2).float()

resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)

resnet18_with_reshape = nn.Sequential(
    PermuteInput(),
    resnet18
)

# Wrapped model: (1, 224, 224, 3) uint8 input, permuted to NCHW inside the model
sample_input = (torch.randint(0, 256, (1, 224, 224, 3), dtype=torch.uint8),)

resnet18_permuted = ai_edge_torch.convert(resnet18_with_reshape.eval(), sample_input)
resnet18_permuted.export("resnet18_permuted.tflite")

# Original model: (1, 3, 224, 224) float32 input, following the PyTorch convention
sample_input_2 = (torch.randint(0, 256, (1, 3, 224, 224), dtype=torch.float32),)

edge_resnet18 = ai_edge_torch.convert(resnet18.eval(), sample_input_2)
edge_resnet18.export("resnet18.tflite")

Without your manual adjustment, this is what it looks like: input = (1, 3, 224, 224) -> <inserted transpose> -> (1, 224, 224, 3) image

With your manual adjustment: input = (1, 224, 224, 3) -> <manual transpose> -> (1, 3, 224, 224) image
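The two shape flows above can be traced with a short sketch. It only tracks shapes: (0, 2, 3, 1) is a shape-level reading of the inserted NCHW-to-NHWC transpose, and (0, 3, 1, 2) is the manual permute in PermuteInput. It does not model where ai-edge-torch actually places transposes in the graph, and the helper names are hypothetical.

```python
from typing import Sequence, Tuple

NCHW_TO_NHWC = (0, 2, 3, 1)  # shape effect of the transpose described above
NHWC_TO_NCHW = (0, 3, 1, 2)  # the manual permute in PermuteInput


def permute_shape(shape: Sequence[int], perm: Sequence[int]) -> Tuple[int, ...]:
    """Shape of x.permute(*perm): output axis i takes input axis perm[i]."""
    return tuple(shape[i] for i in perm)


def inverse_permutation(perm: Sequence[int]) -> Tuple[int, ...]:
    """Permutation that undoes perm, so applying both restores the shape."""
    inv = [0] * len(perm)
    for i, p in enumerate(perm):
        inv[p] = i
    return tuple(inv)


# Without the manual adjustment: NCHW input, NHWC after the transpose.
print(permute_shape((1, 3, 224, 224), NCHW_TO_NHWC))
# With the manual adjustment: NHWC input, NCHW after the permute.
print(permute_shape((1, 224, 224, 3), NHWC_TO_NCHW))
```

The two permutations are inverses of each other, which is why wrapping the model with the manual permute flips which layout appears at the graph input.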

I visualized both models with our Model Explorer tool: https://huggingface.co/spaces/1aurent/model-explorer, https://github.com/google-ai-edge/model-explorer

You are correct that this looks like an integration issue. I think MediaPipe reads the "GraphInputs" nodes and sees that the original model follows the PyTorch (NCHW) convention, which causes the failure. I tried loading your modified model in MediaPipe Studio and ran into the issue you describe, but I believe that is a MediaPipe issue, since AET (ai-edge-torch) does the conversion properly as far as I can tell. Please create a MediaPipe issue for that one: https://github.com/google-ai-edge/mediapipe. For the original model, I think AET is also converting properly, since we don't want to change a user's specified input shape (as with sample_input_2). MediaPipe is allowed to require a specific input shape as well, so I don't think that's a problem either. So I would say you actually "did the right thing" in using a custom model to satisfy your needs.

RubensZimbres commented 3 months ago

Thanks, @pkgoogle. My idea was in fact to insert a custom (unsupported) model into MediaPipe. As of my last update, the only remaining problem was adding metadata. I followed https://www.tensorflow.org/lite/models/convert/metadata and added the following, since my model has 1000 classes rather than the 1001 in the tutorial's labels file (https://github.com/tensorflow/tflite-support/raw/master/tensorflow_lite_support/metadata/python/tests/testdata/image_classifier/labels.txt, downloaded as mobilenet_labels.txt):

input_stats.width = [224]
input_stats.height = [224]
input_stats.num_classes = [1000]
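Because the tutorial's labels file has 1001 entries (as far as I know, the extra one is a "background" class) while torchvision's ResNet-18 outputs 1000 logits, a quick check on the labels file can catch a mismatch before the metadata is written. This is a minimal sketch; the helper name is hypothetical, and it assumes one label per non-empty line in the same order as the model's outputs.

```python
def check_label_count(labels_text: str, num_classes: int) -> list:
    """Return the labels, raising ValueError unless there is one per output class.

    Hypothetical helper: blank lines are ignored, and labels are assumed to
    be listed in the same order as the model's output logits.
    """
    labels = [line.strip() for line in labels_text.splitlines() if line.strip()]
    if len(labels) != num_classes:
        raise ValueError(
            f"labels file has {len(labels)} entries, "
            f"but the model outputs {num_classes} classes"
        )
    return labels
```

For example, check_label_count(open("mobilenet_labels.txt").read(), 1000) should raise, while a 1000-line resnet_labels.txt should pass.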

The only open issue is that the metadata writer's import (from tflite_support import metadata_schema_py_generated as _metadata_fb) only runs successfully on TensorFlow 2.13.0. But this may not concern ai-edge-torch directly. Please feel free to close this issue.

Thanks!

Screenshot from 2024-07-17 22-42-35

pkgoogle commented 3 months ago

Hi @RubensZimbres, you may find better support for that particular issue here: https://github.com/tensorflow/tflite-support. Closing as requested; thanks for your help!