Rishit-dagli / Conformer

An implementation of Conformer: Convolution-augmented Transformer for Speech Recognition, a Transformer Variant in TensorFlow/Keras
Apache License 2.0

tflite conversion fails. #17

Open ryol8888 opened 1 year ago

ryol8888 commented 1 year ago

Describe the bug

  File "/home/modelparser/Conformer/conformer_tf/conformer_tf.py", line 168, in call
    inputs = self.conv(inputs) + inputs
  File "/home/modelparser/Conformer/conformer_tf/conformer_tf.py", line 128, in call
    return self.net(inputs)
  File "/home/modelparser/Conformer/conformer_tf/conformer_tf.py", line 89, in call
    return tf.keras.layers.BatchNormalization(axis=-1)(inputs)
ValueError: Exception encountered when calling layer 'batch_norm' (type BatchNorm).

tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

Call arguments received by layer 'batch_norm' (type BatchNorm):
  • inputs=tf.Tensor(shape=(1, 1024, 1024), dtype=float32)

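When I tried to convert the .h5 model to .tflite using the Conformer block, I got the error above. It is caused by the BatchNorm class, which creates a new tf.keras.layers.BatchNormalization layer (and therefore new variables) inside call on every invocation. I changed the BatchNorm class as follows, and the conversion then succeeded: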

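class BatchNorm(tf.keras.layers.Layer):
    def __init__(self, causal, **kwargs):
        super(BatchNorm, self).__init__(**kwargs)
        self.causal = causal
        # Create the normalization layer once, so its variables are not re-created on every call.
        self.bnorm = tf.keras.layers.BatchNormalization(axis=-1)

    def call(self, inputs):
        if not self.causal:
            return self.bnorm(inputs)
        # Causal mode skips normalization and passes the input through unchanged.
        return tf.identity(inputs)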
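The idea behind the fix can be checked in isolation. The following is a minimal sketch, not the repository's code: `BatchNormOnce` is a hypothetical stand-in that keeps only the non-causal path, and the tensor shapes are arbitrary. It builds the layer once, then calls it through a `tf.function` with two different batch sizes (which forces a retrace) and confirms that no additional variables appear.

```python
import tensorflow as tf


class BatchNormOnce(tf.keras.layers.Layer):
    """Hypothetical stand-in for the fixed layer (non-causal path only)."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # The sublayer, and therefore its variables, is created exactly once.
        self.bnorm = tf.keras.layers.BatchNormalization(axis=-1)

    def call(self, inputs):
        return self.bnorm(inputs)


layer = BatchNormOnce()
layer(tf.zeros((1, 16, 8)))  # eager call builds the variables up front


@tf.function
def forward(x):
    return layer(x)


out_a = forward(tf.zeros((1, 16, 8)))
out_b = forward(tf.zeros((2, 16, 8)))  # different batch size triggers a retrace

# gamma, beta, moving_mean, moving_variance: the count does not grow with retraces.
num_weights = len(layer.weights)
print(num_weights, out_a.shape, out_b.shape)
```

If the `BatchNormalization` layer were instead constructed inside `call`, each trace would try to create a fresh set of variables, which is the situation the error message above describes.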

To Reproduce

I followed the TFLite conversion code from the official site:

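import numpy as np
import tensorflow as tf
# Assumes the conformer_tf package from this repository is installed.
from conformer_tf import ConformerBlock

def conformer():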
    input_layer = tf.keras.layers.Input(shape=(1024, 512),  batch_size=1)

    conformer_block = ConformerBlock(
        dim = 512,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 31,
        attn_dropout = 0.,
        ff_dropout = 0.,
        conv_dropout = 0.
    )(input_layer)

    return tf.keras.Model(inputs=input_layer, outputs=conformer_block)

def representative_dataset():
    # Random calibration samples matching the model's input shape.
    for _ in range(100):
        data = np.random.rand(1, 1024, 512)
        yield [data.astype(np.float32)]

net = conformer()
converter = tf.lite.TFLiteConverter.from_keras_model(net)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.representative_dataset = representative_dataset
tflite_model = converter.convert()
with open("conformer.tflite", "wb") as tflite_file:
    tflite_file.write(tflite_model)
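Because the script sets inference_input_type and inference_output_type to tf.int8, the converted model takes and returns int8 tensors, so float data generally has to be mapped to int8 using the scale and zero point that the interpreter reports for each tensor. Below is a minimal NumPy sketch of that affine mapping, assuming per-tensor quantization. The helper names and the example scale and zero point are made up for illustration; real values would come from `interpreter.get_input_details()[0]["quantization"]`.

```python
import numpy as np


def quantize_input(x, scale, zero_point):
    """Map float values to int8 using q = round(x / scale) + zero_point."""
    q = np.round(np.asarray(x, dtype=np.float32) / scale) + zero_point
    return np.clip(q, -128, 127).astype(np.int8)


def dequantize_output(q, scale, zero_point):
    """Invert the mapping: x is approximately (q - zero_point) * scale."""
    return (np.asarray(q, dtype=np.float32) - zero_point) * scale


# Illustrative parameters only (hypothetical, not read from a real model).
scale, zero_point = 1.0 / 255.0, -128

x = np.array([0.0, 0.2, 1.0], dtype=np.float32)
q = quantize_input(x, scale, zero_point)
x_back = dequantize_output(q, scale, zero_point)
print(q, x_back)
```

The round trip loses at most about one quantization step (`scale`) per value, which is why the representative dataset should cover the range of inputs expected at inference time.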
