google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0

Torch `unsqueeze` / `reshape` dimensions not preserved upon TFLite conversion #288

Open gudgud96 opened 1 month ago

gudgud96 commented 1 month ago

Description of the bug:

I am working on making my model GPU-compatible on TFLite, and one of the issues to solve is working around operators that do not support broadcasting on the GPU. One solution is to keep the tensors at the same rank through reshape or unsqueeze, but I find that the expanded dimensions are not preserved after TFLite conversion.
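For context, the broadcasting rule that NumPy and PyTorch apply to elementwise ops such as ADD can be sketched in a few lines of plain Python. The helpers below (broadcast_shape, needs_broadcast, align_rank) are hypothetical names for illustration only, not part of ai-edge-torch or TFLite. They show why shapes like (1, 20, 32) and (20, 32) rely on broadcasting, and how prepending size-1 dimensions (the unsqueeze trick) makes the ranks match. Treating any shape difference, including a rank difference, as "needs broadcasting" is an assumption about how a strict GPU kernel sees the operands.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Hypothetical helper: result shape of an elementwise op (NumPy/PyTorch rules).

    Shapes are aligned from the trailing dimension; missing leading dims count
    as 1, and each aligned pair must be equal or contain a 1.
    """
    result = []
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da != db and 1 not in (da, db):
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        # A size-1 dim stretches to match the other operand.
        result.append(db if da == 1 else da)
    return tuple(reversed(result))


def needs_broadcast(a, b):
    """Hypothetical helper: True if the operand shapes differ in any way (rank or size)."""
    return tuple(a) != tuple(b)


def align_rank(shape, rank):
    """Hypothetical helper: prepend size-1 dims, like repeated unsqueeze(0)."""
    if len(shape) > rank:
        raise ValueError("shape already has a higher rank than requested")
    return (1,) * (rank - len(shape)) + tuple(shape)


lhs, rhs = (1, 20, 32), (20, 32)
print(broadcast_shape(lhs, rhs))                 # (1, 20, 32)
print(needs_broadcast(lhs, rhs))                 # True
print(needs_broadcast(lhs, align_rank(rhs, 3)))  # False
```

The result shape is the same either way; only the rank of one operand differs, which is exactly the kind of mismatch that keeping tensors at the same rank is meant to avoid.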

Here is a minimal code snippet that reproduces the issue:

import torch
import torch.nn as nn
import ai_edge_torch


class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)
        self.linear2 = nn.Linear(32, 32)

    def forward(self, x):
        """
        x - (1, 20, 32)
        """
        B, T, D = x.shape
        tmp = x                                                 # residual branch, (1, 20, 32)
        x = x.reshape(4, B * T, D // 4)                         # (4, 20, 8)
        x = self.linear(x)                                      # (4, 20, 8)
        x = x.transpose(0, 1).contiguous().reshape(1, T, D)     # (1, 20, 32)
        x = self.linear2(x)                                     # (1, 20, 32)

        # residual add: in PyTorch both operands are (1, 20, 32)
        x = tmp + x
        x = x.squeeze(0)                                        # (20, 32)
        return x


test_model = TestModel()
x = torch.rand(1, 20, 32)
with torch.no_grad():
    y = test_model(x)                                           # (20, 32)

# Convert with ai-edge-torch and write the TFLite flatbuffer.
edge_model = ai_edge_torch.convert(test_model, (x,))
edge_model.export("test_model.tflite")
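To double-check the shape comments without installing PyTorch or TFLite, here is a NumPy emulation of the same forward pass. The weights are random stand-ins (an assumption, not the trained parameters), and linear() is a hypothetical helper reproducing the nn.Linear formula y = x @ W^T + b. It only confirms that, before conversion, both operands of the residual add are rank 3 with shape (1, 20, 32).

```python
import numpy as np

rng = np.random.default_rng(0)


def linear(x, weight, bias):
    # Same math as torch.nn.Linear: y = x @ W^T + b over the last dimension.
    return x @ weight.T + bias


# Hypothetical random parameters standing in for self.linear and self.linear2.
w1, b1 = rng.standard_normal((8, 8)), rng.standard_normal(8)
w2, b2 = rng.standard_normal((32, 32)), rng.standard_normal(32)


def forward(x):
    B, T, D = x.shape
    tmp = x
    x = x.reshape(4, B * T, D // 4)             # (4, 20, 8)
    x = linear(x, w1, b1)                       # (4, 20, 8)
    # torch transpose(0, 1) on a 3-D tensor == NumPy transpose(1, 0, 2);
    # reshape copies in C order, like .contiguous().reshape(...).
    x = x.transpose(1, 0, 2).reshape(1, T, D)   # (1, 20, 32)
    x = linear(x, w2, b2)                       # (1, 20, 32)
    assert tmp.shape == x.shape == (1, T, D)    # residual operands share rank 3
    return (tmp + x).squeeze(0)                 # (20, 32)


out = forward(rng.standard_normal((1, 20, 32)))
print(out.shape)  # (20, 32)
```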

At the residual add line, tmp and x are expected to have the same shape, (1, 20, 32). However, after converting to TFLite, running

tf.lite.experimental.Analyzer.analyze(model_path="test_model.tflite", gpu_compatibility=True)

reports a GPU compatibility warning, which means the model cannot fully utilize the GPU:

Subgraph#0 main(T#0) -> [T#14]
  Op#0 RESHAPE(T#0, T#7[4, 20, 8]) -> [T#8]
  Op#1 FULLY_CONNECTED(T#8, T#2, T#5) -> [T#9]
  Op#2 TRANSPOSE(T#9, T#4[1, 0, 2]) -> [T#10]
  Op#3 RESHAPE(T#10, T#3[20, 32]) -> [T#11]
  Op#4 FULLY_CONNECTED(T#11, T#1, T#6) -> [T#12]
  Op#5 ADD(T#0, T#12) -> [T#13]
GPU COMPATIBILITY WARNING: Doesn't support broadcasting - input0: [1,20,32], input1: [20,32]
  Op#6 RESHAPE(T#13, T#3[20, 32]) -> [T#14]

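It seems the conversion "squeezes" the second Reshape (Op#3) so that it outputs a 2-D tensor, [20, 32], instead of the 3-D tensor (1, 20, 32) produced in PyTorch. The ADD (Op#5) then has to broadcast the original input T#0 ([1, 20, 32]) against T#12 ([20, 32]).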

I would like to understand how the Reshape output shape is optimized or changed during the conversion process. That would help me figure out how to preserve the expanded dimensions and keep both ADD inputs at the same shape, so that the "Doesn't support broadcasting" warning goes away.
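One possible direction, untested with ai-edge-torch: since B == 1 here, the residual add could be done at rank 2 by reshaping both operands to (T, D) before adding, instead of adding at rank 3 and squeezing afterwards. The NumPy sketch below uses random stand-in tensors (an assumption) and only shows that the two orderings give identical values; it says nothing about which ops the converter will actually emit.

```python
import numpy as np

rng = np.random.default_rng(0)
T, D = 20, 32
tmp = rng.standard_normal((1, T, D))   # residual branch, rank 3
y = rng.standard_normal((1, T, D))     # stand-in for the linear2 output

# Original pattern: rank-3 add, then squeeze(0).
original = (tmp + y).squeeze(0)

# Hypothetical workaround: bring both operands to the same rank-2 shape
# (T, D) before adding, so the add itself never needs broadcasting.
rank2 = tmp.reshape(T, D) + y.reshape(T, D)

print(original.shape, rank2.shape)       # (20, 32) (20, 32)
print(np.array_equal(original, rank2))   # True
```

In the PyTorch model this would correspond to something like `x = tmp.reshape(T, D) + x.reshape(T, D)` in place of the add and squeeze; whether the converted graph then contains an ADD with two [20, 32] inputs would still need to be checked with the Analyzer.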

Actual vs expected behavior:

Expected: the reshape dimensions in the converted TFLite model match the PyTorch tensor shapes. Actual: the converted TFLite tensors appear to be squeezed by some unknown optimization.


pkgoogle commented 1 month ago

I was able to replicate exactly as above.

gudgud96 commented 1 month ago

Hi @pkgoogle, are there any updates on this issue?