AlexanderLutsenko / nobuco

Pytorch to Keras/Tensorflow/TFLite conversion made intuitive
MIT License

Force static shapes through the graph? #48

Closed tonystratum closed 2 months ago

tonystratum commented 3 months ago

Trying to convert a model to tflite for the Coral TPU, which only supports static shapes. Is there any way to force static shapes throughout the whole graph?

P.S. Magnificent work! Thanks!

AlexanderLutsenko commented 3 months ago

Hi, thanks! Strictly speaking, TFLite only supports static shapes. I never tried Coral, but I think it should be fine.

tonystratum commented 3 months ago

@AlexanderLutsenko when I inspected the converted model with Model Explorer (https://github.com/google-ai-edge/model-explorer), it showed shapes of size -1, which I presume are dynamic. Am I doing something wrong?

AlexanderLutsenko commented 3 months ago

Was it the input tensor that had size -1? That should never occur unless you annotate some dimensions as dynamic in the input_shapes argument. As for intermediate/output tensors, the result of certain ops has to be dynamic. For example:

import torch
from torch import nn
import nobuco

class FilterBoxes(nn.Module):
    def __init__(self):
        super().__init__()
        self.threshold = 0.1

    def forward(self, boxes, scores):
        mask = scores > self.threshold
        return boxes[mask], scores[mask]

boxes = torch.normal(0, 1, size=(128, 4))
scores = torch.normal(0, 1, size=(128,))
pytorch_module = FilterBoxes().eval()

keras_model = nobuco.pytorch_to_keras(
    pytorch_module,
    args=[boxes, scores],
)

"1" in the tensor shapes should be "-1"; Netron visualizes it incorrectly.

Obviously, we do not know in advance how many scores will be greater than the threshold. To avoid dynamic shapes, the greater-than (gt) filter is usually replaced with topk:

    def forward(self, boxes, scores):
        top_scores, top_indices = torch.topk(scores, k=10, dim=-1)
        top_boxes = boxes[top_indices]
        return top_boxes, top_scores

That's about as much as I can tell without looking at your model. If that's not enough, feel free to share the part of it that causes the problem.
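
To make the difference concrete, here is a minimal pure-Python sketch using plain lists instead of tensors. The helper names `filter_by_threshold` and `top_k` are made up for illustration and are not part of nobuco or PyTorch. The threshold filter returns a result whose length depends on the data, while the top-k selection always returns exactly k items.

```python
import heapq


def filter_by_threshold(scores, threshold):
    """Keep scores above the threshold; the output length depends on the data."""
    return [s for s in scores if s > threshold]


def top_k(scores, k):
    """Keep the k largest scores; the output length is always k (given len(scores) >= k)."""
    return heapq.nlargest(k, scores)


batch_a = [0.9, 0.05, 0.3, 0.02]
batch_b = [0.01, 0.02, 0.03, 0.04]

# Thresholding: the length changes from batch to batch (a "dynamic shape").
print(len(filter_by_threshold(batch_a, 0.1)), len(filter_by_threshold(batch_b, 0.1)))  # 2 0

# Top-k: the length is fixed by k, whatever the data looks like (a "static shape").
print(len(top_k(batch_a, 2)), len(top_k(batch_b, 2)))  # 2 2
```

The trade-off is that top-k may return low-scoring entries when fewer than k scores are actually interesting, so downstream code usually still has to check the returned scores.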

johan-sightic commented 3 months ago

Hi! I have a similar problem but it may just be a limitation of tensorflow or tflite.

I have a 2-stage CNN in pytorch where the first stage uses a scaled down image to predict the center of a crop. The second stage then takes the crop in full resolution as input to make a prediction. Both the images and crops are of fixed size (640x480 and 128x128)

I convert the model from pytorch to tflite and want to run it on an Android phone GPU. Everything works fine on my computer: I get identical output from pytorch and tflite. To make it run on the Android GPU I had to make some changes to the cropping, and it now looks a bit hacky, but the tf.lite.experimental.Analyzer.analyze() function says it is GPU compatible. However, when I try to run it on the Android GPU it crashes and says that the size of the crop is dynamic and that only static-sized tensors are supported.

I can see why it thinks the size is dynamic, but it really isn't. Can I convince it that it is static?

Simplified example code

```python
"""
Convert CNN from pytorch to tflite.

Requirements:
- tensorflow-cpu==2.15.0
"""
import numpy as np
import torch
from torch import nn
import nobuco
import tensorflow as tf
from nobuco import ChannelOrder


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.stage_1_conv = nn.Conv2d(3, 4, kernel_size=8, stride=8)
        self.stage_1_relu = nn.ReLU()
        self.stage_1_linear = nn.Linear(4 * 20 * 15, 2)
        self.stage_1_sigmoid = nn.Sigmoid()

        self.stage_2_conv = nn.Conv2d(3, 4, kernel_size=8, stride=8)
        self.stage_2_relu = nn.ReLU()
        self.stage_2_linear = nn.Linear(4 * 16 * 16, 5)
        self.stage_2_tanh = nn.Tanh()

    @torch.no_grad()
    def forward(self, image: torch.Tensor) -> torch.Tensor:
        # Cast and normalize image (not GPU compatible, but that is fine)
        image = image.float() / 255

        ##### Predict crop center (stage 1) #####
        x = self.stage_1_conv(image[:, :, ::4, ::4])
        x = self.stage_1_relu(x.flatten(start_dim=1))
        x = self.stage_1_linear(x)
        rel_crop_center = self.stage_1_sigmoid(x)

        ##### Extract crop #####
        image_size = torch.tensor((640, 480), dtype=torch.int32)
        half_crop_size = torch.tensor((64, 64), dtype=torch.int32)
        crop_center = (rel_crop_center * image_size).to(torch.int32)

        # Crop has to be within frame
        crop_center = torch.clamp(crop_center, half_crop_size, image_size - half_crop_size)

        # Hacky way to crop since TFLite GPU only supports slicing in 4 dimensions
        # and the only way I could make scalar tensor was with `min()`
        tl = crop_center - half_crop_size
        br = crop_center + half_crop_size
        top = tl.view(2, 1, 1, 1)[0:1, :, :, :].min()
        left = tl.view(2, 1, 1, 1)[1:2, :, :, :].min()
        bottom = br.view(2, 1, 1, 1)[0:1, :, :, :].min()
        right = br.view(2, 1, 1, 1)[1:2, :, :, :].min()
        crop = image[:, :, top:bottom, left:right]

        ##### Predict full resolution crop (stage 2) #####
        x = self.stage_2_conv(crop)
        x = self.stage_2_relu(x.flatten(start_dim=1))
        out = self.stage_2_tanh(self.stage_2_linear(x))

        return out


# Test pytorch model
print("=" * 20, "Pytorch model", "=" * 20)
pytorch_module = MyModel().eval()
dummy_input = torch.randint(0, 256, size=(1, 3, 640, 480), dtype=torch.uint8)
dummy_output = pytorch_module(dummy_input)
print("Input: ", dummy_input.shape, dummy_input.dtype)
print("Output: ", dummy_output.shape, dummy_output.dtype)

# Convert the model with nobuco to keras.
print("=" * 20, "Nobuco conversion", "=" * 20)
keras_model = nobuco.pytorch_to_keras(
    pytorch_module,
    args=[dummy_input],
    inputs_channel_order=ChannelOrder.TENSORFLOW,
    outputs_channel_order=ChannelOrder.TENSORFLOW,
    save_trace_html=True,
)

# Convert the model to tflite.
print("=" * 20, "TFLite conversion", "=" * 20)
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
tflite_model = converter.convert()

# Save the model.
with open("mymodel.tflite", "wb") as f:
    f.write(tflite_model)

# Analyze model
print("=" * 20, "Model analysis", "=" * 20)
tf.lite.experimental.Analyzer.analyze(model_content=tflite_model, gpu_compatibility=True)

# Test run tflite model
print("=" * 20, "TFLite model test", "=" * 20)
interpreter = tf.lite.Interpreter("mymodel.tflite")
my_signature = interpreter.get_signature_runner()
output = my_signature(args_0=tf.constant(dummy_input.permute(0, 2, 3, 1).numpy()))
pred = list(output.values())[0]
assert np.allclose(np.array(pred), np.array(dummy_output)), "Model output mismatch!"
```

AlexanderLutsenko commented 3 months ago

@johan-sightic Well, that's unfortunate. But I see why TFLite fails to recognize the crop shape as static. Consider this:

y0 = center_y - h
y1 = center_y + h
crop = image[:, y0: y1]

Since center_y is a tensor, so are y0 and y1. When you invoke image[:, y0:y1], the indexing operator ([]) has no way of knowing that y0 and y1 are related and that y1 - y0 is constant. It wouldn't be a problem if the op accepted (center_y, h) as cropping parameters instead of (y0, y1). If you want to assemble such an op from existing TensorFlow primitives, brace yourself for some real jank:

import tensorflow as tf
import torch
from torch import nn
import nobuco

@nobuco.traceable
def get_crop(x, center_y, center_x, h, w):
    return x[:, :, center_y - h:center_y + h, center_x - w:center_x + w]

@nobuco.converter(get_crop, channel_ordering_strategy=nobuco.ChannelOrderingStrategy.FORCE_TENSORFLOW_ORDER)
def converter_get_crop(x, center_y, center_x, h, w):
    def func(x, center_y, center_x, h, w):
        x = tf.roll(input=x, shift=-(center_y - h), axis=1)
        x = tf.roll(input=x, shift=-(center_x - w), axis=2)
        return x[:, :h*2, :w*2]
    return func

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.stage_1_conv = nn.Conv2d(3, 4, kernel_size=8, stride=8)
        self.stage_1_relu = nn.ReLU()
        self.stage_1_linear = nn.Linear(4 * 20 * 15, 2)
        self.stage_1_sigmoid = nn.Sigmoid()

        self.stage_2_conv = nn.Conv2d(3, 4, kernel_size=8, stride=8)
        self.stage_2_relu = nn.ReLU()
        self.stage_2_linear = nn.Linear(4 * 16 * 16, 5)
        self.stage_2_tanh = nn.Tanh()

    @torch.no_grad()
    def forward(self, image: torch.Tensor) -> torch.Tensor:
        # Cast and normalize image (not GPU compatible, but that is fine)
        image = image.float() / 255

        ##### Predict crop center (stage 1) #####
        x = self.stage_1_conv(image[:, :, ::4, ::4])
        x = self.stage_1_relu(x.flatten(start_dim=1))
        x = self.stage_1_linear(x)
        rel_crop_center = self.stage_1_sigmoid(x)

        ##### Extract crop #####
        image_size = torch.tensor((640, 480), dtype=torch.int32)
        half_crop_size = torch.tensor((64, 64), dtype=torch.int32)
        crop_center = (rel_crop_center * image_size).to(torch.int32)

        # Crop has to be within frame
        crop_center = torch.clamp(crop_center, half_crop_size, image_size - half_crop_size)
        center_y, center_x = crop_center[0]

        crop = get_crop(image, center_y, center_x, 64, 64)

        x = self.stage_2_conv(crop)
        x = self.stage_2_relu(x.flatten(start_dim=1))
        out = self.stage_2_tanh(self.stage_2_linear(x))

        return out

One drawback of this approach, apart from being suboptimal, is that tf.roll is not included in TFLITE_BUILTINS. You'd need to download or compile TFLite binaries with an extended operator set. Also, don't forget to enable Flex ops in TFLiteConverter:

converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()

Tested on an Android device with OpenCL and XNNPACK delegates, so it might work for you, too.
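
For readers wondering why rolling and then slicing gives the same result as a direct crop, here is a small sketch on a plain Python list. It only mimics the idea in one dimension with `collections.deque`; the helper names `roll_left` and `crop_via_roll` are hypothetical and this is not the converter code itself.

```python
from collections import deque


def roll_left(values, shift):
    """Cyclically shift a list to the left by `shift` positions."""
    d = deque(values)
    d.rotate(-shift)  # deque.rotate(n) shifts right, so negate to shift left
    return list(d)


def crop_via_roll(values, start, size):
    """Move index `start` to position 0, then take a fixed-size prefix."""
    return roll_left(values, start)[:size]


row = list(range(10))
start, size = 3, 4

# The prefix length (`size`) is a constant, so the result length never depends on `start`.
print(crop_via_roll(row, start, size))  # [3, 4, 5, 6]
print(row[start:start + size])          # [3, 4, 5, 6]
```

The price is that the whole row gets shifted just to read a small window, which is why the approach is called suboptimal above.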

johan-sightic commented 2 months ago

I see, thank you very much for such a good answer! I will run some tests to see whether this is a valid approach for us.

AlexanderLutsenko commented 2 months ago

@johan-sightic Hey, I just found a much better solution:

@nobuco.traceable
def get_crop(x, center_y, center_x, h, w):
    return x[:, :, center_y - h:center_y + h, center_x - w:center_x + w]

@nobuco.converter(get_crop, channel_ordering_strategy=nobuco.ChannelOrderingStrategy.FORCE_TENSORFLOW_ORDER)
def converter_get_crop(x, center_y, center_x, h, w):
    def func(x, center_y, center_x, h, w):
        return tf.image.crop_to_bounding_box(x, center_y - h, center_x - w, h*2, w*2)
    return func

It translates to normal tf.slice, so no custom binaries required.
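
As a rough intuition for why this form keeps the shape static: tf.slice takes a begin offset and a size, and here the size (h*2, w*2) is a plain constant even though the offset is a tensor. The toy model below is only an illustration of that reasoning, not how TFLite actually infers shapes; `RuntimeValue`, `static_len_start_end` and `static_len_begin_size` are made-up names.

```python
from typing import Optional


class RuntimeValue:
    """Placeholder for a value that is only known when the model runs (e.g. a tensor)."""

    def __add__(self, other):
        return RuntimeValue()

    def __sub__(self, other):
        return RuntimeValue()


def static_len_start_end(start, end) -> Optional[int]:
    """Length of [start:end] as visible at conversion time; None means dynamic."""
    if isinstance(start, int) and isinstance(end, int):
        return end - start
    # Two unrelated runtime values: their difference cannot be derived statically.
    return None


def static_len_begin_size(begin, size: int) -> Optional[int]:
    """Length of a (begin, size) slice: known whenever `size` is a constant."""
    return size


center_y = RuntimeValue()
h = 64

print(static_len_start_end(center_y - h, center_y + h))  # None
print(static_len_begin_size(center_y - h, 2 * h))        # 128
```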

johan-sightic commented 2 months ago

Thank you very much! This seems to work! I'm not sure how much of my model is actually run on cpu vs gpu because I don't quite understand the log output but at least it runs about 75% faster now :smile:

AlexanderLutsenko commented 2 months ago

@tonystratum I'm closing the issue for now due to inactivity. Please reopen if it's still relevant.