nyadla-sys / whisper.tflite

Optimized OpenAI's Whisper TFLite Port for Efficient Offline Inference on Edge Devices
MIT License
177 stars, 32 forks

How to change the names of Whisper TFLite nodes #40

Open: zzy981019 opened this issue 2 months ago

zzy981019 commented 2 months ago

Hi @nyadla-sys,

I am using your generate_tflite_from_whisper.ipynb to convert Whisper models to TFLite. When I open the converted model in Netron, I find that the names of its input and output nodes are very long and not what I expected. Do you have any suggestions for changing the names of the input and output nodes?

[Screenshot: whisper_tflite_from_netron]

Many thanks!

nyadla-sys commented 2 months ago

Try something like the following:

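import tensorflow as tf
from transformers import TFWhisperForConditionalGeneration

# Assumption: `model` is the TF Whisper model loaded earlier in the notebook, e.g.:
model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

class GenerateModel(tf.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model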

    @tf.function(
        input_signature=[
            tf.TensorSpec((1, 80, 3000), tf.float32, name="new_input_name"),  # Updated input name
        ],
    )
    def serving(self, new_input_name):  # Updated parameter name
        outputs = self.model.generate(
            new_input_name,
            max_new_tokens=450,  # Change as needed
            return_dict_in_generate=True,
        )
        return {"new_output_name": outputs["sequences"]}  # Updated output name

saved_model_dir = '/content/tf_whisper_saved'
tflite_model_path = 'whisper-tiny.en.tflite'

# Create and save the TensorFlow model with updated names
generate_model = GenerateModel(model=model)
tf.saved_model.save(generate_model, saved_model_dir, signatures={"serving_default": generate_model.serving})

# Convert the model to TensorFlow Lite
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # Enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # Enable TensorFlow ops.
]
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save the converted TFLite model
with open(tflite_model_path, 'wb') as f:
    f.write(tflite_model)
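To check which names the converted model actually exposes without converting the full Whisper model, the same save-and-convert steps can be run on a trivial module. This is a minimal sketch: `ToyModel` is a hypothetical stand-in with the same (1, 80, 3000) input signature, and its `reduce_mean` body exists only so the graph has an output. It prints the signature-level names next to the raw input tensor name that Netron displays.

```python
import tempfile

import tensorflow as tf


class ToyModel(tf.Module):
    """Hypothetical stand-in for GenerateModel: same input signature, trivial body."""

    @tf.function(
        input_signature=[tf.TensorSpec((1, 80, 3000), tf.float32, name="input_features")]
    )
    def serving(self, input_features):
        # Average over the mel axis just to produce some output tensor.
        return {"sequences": tf.reduce_mean(input_features, axis=1)}


toy = ToyModel()
saved_model_dir = tempfile.mkdtemp()
tf.saved_model.save(toy, saved_model_dir, signatures={"serving_default": toy.serving})

tflite_model = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir).convert()
interpreter = tf.lite.Interpreter(model_content=tflite_model)

# Signature-level names: the TensorSpec name and the keys of the returned dict.
signatures = interpreter.get_signature_list()
print(signatures)

# Raw tensor name of the first graph input, as shown in Netron.
input_name = interpreter.get_input_details()[0]["name"]
print(input_name)
```

If the wrapper behaves as expected, `get_signature_list()` should report `input_features` and `sequences`, while the raw tensor name carries extra decoration added by the converter.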
nyadla-sys commented 2 months ago


zzy981019 commented 2 months ago

Thanks! I tried the code you posted and found that the name of the input tensor has been updated (screenshot attached). However, must the input tensor name follow the format signatureKey + "_" + inputName + ":0"? Is there any way to set it to just "input_features"?

[Screenshot: whisper_tflite_from_netron_new]
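The `serving_default_<name>:0` form appears to come from the converter naming SavedModel input tensors after the signature key, so `TensorSpec(name=...)` controls only the middle part. A workaround that sidesteps the prefix, rather than removing it, is to address the model through its signature instead of raw tensor names, using `tf.lite.Interpreter.get_signature_runner`. The sketch below assumes a TensorFlow 2.x release that provides signature runners, and again uses a hypothetical `ToyModel` in place of Whisper; the all-ones mel input is dummy data.

```python
import tempfile

import numpy as np
import tensorflow as tf


class ToyModel(tf.Module):
    """Hypothetical stand-in for the Whisper GenerateModel wrapper."""

    @tf.function(
        input_signature=[tf.TensorSpec((1, 80, 3000), tf.float32, name="input_features")]
    )
    def serving(self, input_features):
        # Average over the mel axis just to produce some output tensor.
        return {"sequences": tf.reduce_mean(input_features, axis=1)}


toy = ToyModel()
saved_model_dir = tempfile.mkdtemp()
tf.saved_model.save(toy, saved_model_dir, signatures={"serving_default": toy.serving})
tflite_model = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir).convert()

interpreter = tf.lite.Interpreter(model_content=tflite_model)
# The runner is looked up by signature key and takes inputs by their signature names,
# so the prefixed raw tensor name never appears in the calling code.
runner = interpreter.get_signature_runner("serving_default")

mel = np.ones((1, 80, 3000), dtype=np.float32)  # dummy log-mel input
outputs = runner(input_features=mel)
sequences = outputs["sequences"]
print(sequences.shape)
```

Because the runner is keyed by `input_features` and `sequences`, the calling code only ever uses the names chosen in the wrapper.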