google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0

LSTMs are being unrolled and decomposed after conversion #53

Open: Doomski99 opened this issue 3 months ago

Doomski99 commented 3 months ago

Description of the bug:

I'm submitting this bug at the request of @pkgoogle, after we found it in #62275.

The converter appears to do two unintended things when handling LSTMs:

  1. The LSTM is decomposed into separate operators instead of being converted into a single "UnidirectionalSequenceLSTM" operator, which is what TensorFlow produces by default.
  2. The LSTM is unrolled without the user asking for it. In TensorFlow, the Keras LSTM layer has an "unroll" argument that enables unrolling, but it defaults to False. If the first bug is fixed, this one may no longer matter, except when the user wants to manipulate the hidden states: at least in TensorFlow, that forces the converter to fall back to decomposed operators inside a "While" loop, as I described in #62775. In that case, it should be up to the user to choose between the loop and the unrolled version.
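To make the "unrolled" distinction concrete, here is a minimal PyTorch sketch (an illustration only, not how ai-edge-torch actually lowers the op) showing that nn.LSTM over a whole sequence computes the same result as calling an nn.LSTMCell with the same weights once per timestep. The per-timestep form is what unrolling produces: the graph gets one copy of the cell's operators for every timestep, whereas a fused "UnidirectionalSequenceLSTM" op or a "While" loop keeps a single copy. The sizes are small, arbitrary values chosen for the example.

```python
import torch
from torch import nn

torch.manual_seed(0)

INPUT_SIZE, HIDDEN_SIZE, BATCH, STEPS = 8, 4, 2, 5

# Sequence-level LSTM: the form we would like to see as one fused operator.
lstm = nn.LSTM(INPUT_SIZE, HIDDEN_SIZE, batch_first=True)

# A single cell sharing the LSTM's weights (both use the same i, f, g, o gate layout).
cell = nn.LSTMCell(INPUT_SIZE, HIDDEN_SIZE)
with torch.no_grad():
    cell.weight_ih.copy_(lstm.weight_ih_l0)
    cell.weight_hh.copy_(lstm.weight_hh_l0)
    cell.bias_ih.copy_(lstm.bias_ih_l0)
    cell.bias_hh.copy_(lstm.bias_hh_l0)

x = torch.randn(BATCH, STEPS, INPUT_SIZE)

with torch.no_grad():
    fused_out, (h_fused, c_fused) = lstm(x)

    # "Unrolled" form: one explicit cell invocation per timestep,
    # starting from zero states (nn.LSTM's default initial state).
    h = torch.zeros(BATCH, HIDDEN_SIZE)
    c = torch.zeros(BATCH, HIDDEN_SIZE)
    step_outputs = []
    for t in range(STEPS):
        h, c = cell(x[:, t, :], (h, c))
        step_outputs.append(h)
    unrolled_out = torch.stack(step_outputs, dim=1)

print(torch.allclose(fused_out, unrolled_out, atol=1e-5))
print(torch.allclose(h_fused[0], h, atol=1e-5))
```

Both forms are numerically equivalent; the difference the issue is about is graph size and operator choice, not the math.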

Actual vs expected behavior:

You can find the test code below.

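Actual behavior: with ai-edge-torch, the LSTM is clearly decomposed and unrolled: [image: graph of the model converted with ai-edge-torch]

Expected behavior: with TensorFlow, the LSTM becomes a single "UnidirectionalSequenceLSTM" operator: [image: graph of the model converted with TensorFlow]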

Any other information you'd like to share?

PyTorch code:

import torch
from torch import nn
import ai_edge_torch

class SimpleModel(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()

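        # batch_first=True so a (16, 43, 256) input is read as batch=16, steps=43, as in the TF model
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)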
        self.d1 = nn.Linear(hidden_size, 1)

    def forward(self, x):

        x, (h0, c0) = self.lstm(x)
        x = self.d1(x)

        return x

model = SimpleModel(256, 64)
sample_inputs = (torch.randn(16, 43, 256),)

edge_model = ai_edge_torch.convert(model.eval(), sample_inputs)
edge_model.export("simple_model.tflite")
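One detail worth checking when comparing the two exporters: nn.LSTM defaults to batch_first=False, in which case a (16, 43, 256) input is read as 16 timesteps for a batch of 43. The TensorFlow model uses 16 as the batch size and 43 as the number of steps. This short sketch prints the final hidden-state shape (num_layers, batch, hidden_size) under both settings, which shows how each layout interprets the same tensor.

```python
import torch
from torch import nn

x = torch.randn(16, 43, 256)

# Default layout (seq_len, batch, features): 16 steps, batch of 43.
lstm_seq_first = nn.LSTM(256, 64)
_, (h_seq_first, _) = lstm_seq_first(x)

# batch_first=True layout (batch, seq_len, features): batch of 16, 43 steps.
lstm_batch_first = nn.LSTM(256, 64, batch_first=True)
_, (h_batch_first, _) = lstm_batch_first(x)

# h_n is always (num_layers, batch, hidden_size), regardless of batch_first.
print(tuple(h_seq_first.shape))    # (1, 43, 64)
print(tuple(h_batch_first.shape))  # (1, 16, 64)
```

Only the batch_first=True variant matches the [16, 43, 256] TensorSpec used in the TensorFlow script, so it gives the closer like-for-like comparison.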

TensorFlow code:

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, LSTM

model_name = "1x_LSTM_64_float32"
input_length = 256

class SimpleModel(Model):
  def __init__(self, input_shape, hidden_size):
    super().__init__()

    self.lstm = LSTM(hidden_size, return_sequences=True, return_state=True, input_shape=(None, input_shape))

    self.d1 = Dense(1, input_shape=(None, hidden_size))

  def call(self, x):

    x, h0, c0 = self.lstm(x)
    x = self.d1(x)

    return x

model = SimpleModel(input_length, 64)

# call() returns a single tensor, so it must not be unpacked into two values
out = model(tf.random.uniform([16, 43, 256]))

print(np.mean(out))

model_path = f"{model_name}.tf"

run_model = tf.function(lambda x: model(x))
BATCH_SIZE = 16
STEPS = 43
INPUT_SIZE = 256
concrete_func = run_model.get_concrete_function(
    tf.TensorSpec([BATCH_SIZE, STEPS, INPUT_SIZE], tf.float32))

model.save(model_path, save_format = 'tf', signatures=concrete_func)

converter = tf.lite.TFLiteConverter.from_saved_model(model_path)
converter.target_spec.supported_ops = [
  tf.lite.OpsSet.TFLITE_BUILTINS,
]

tflite_model = converter.convert()
with open(f"{model_name}.tflite", "wb") as f:
  f.write(tflite_model)

Python version: 3.11.9
ai_edge_torch version: 0.1.1 (installed in a fresh conda environment following the instructions in the release section)
Operating system: Ubuntu 22.04.3 LTS in WSL2