nyadla-sys / whisper.tflite

Optimized OpenAI's Whisper TFLite Port for Efficient Offline Inference on Edge Devices

Export distilled whisper to TFLite #26

Open mikel-brostrom opened 7 months ago

mikel-brostrom commented 7 months ago

Have you tried the TFLite export notebook on Distilled Whisper (https://github.com/huggingface/distil-whisper)? I modified the script like this:

import tensorflow as tf

from datasets import load_dataset
from transformers import WhisperProcessor, WhisperFeatureExtractor, TFWhisperForConditionalGeneration, WhisperTokenizer

# NOTE: feature extractor from the distilled checkpoint; tokenizer and model still from the original
feature_extractor = WhisperFeatureExtractor.from_pretrained("distil-whisper/distil-small.en")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small.en", predict_timestamps=False)
processor = WhisperProcessor(feature_extractor, tokenizer)
model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-small.en")
# Loading dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

inputs = feature_extractor(
    ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="tf"
)
input_features = inputs.input_features

# Generating Transcription
generated_ids = model.generate(input_features=input_features)
print(generated_ids)
transcription = processor.tokenizer.decode(generated_ids[0])
print(transcription)
model.save('/content/tf_whisper_saved')

I am not sure whether the same tokenizer as for regular Whisper can be used.
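One quick sanity check (a sketch of my own, not from the notebook): Distil-Whisper is documented to reuse the original Whisper tokenizer, so the two vocabularies should be identical:

from transformers import WhisperTokenizer

t_orig = WhisperTokenizer.from_pretrained("openai/whisper-small.en")
t_distil = WhisperTokenizer.from_pretrained("distil-whisper/distil-small.en")
print(t_orig.get_vocab() == t_distil.get_vocab())  # True if the vocabularies match

Anyway, this piece of code crashes: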

# Convert the model
saved_model_dir = '/content/tf_whisper_saved'  # path used in model.save() above
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [
  tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
  tf.lite.OpsSet.SELECT_TF_OPS  # enable select TensorFlow ops.
]
tflite_model = converter.convert()  # <- crashes here

Any idea on how to get a successful distilled whisper export?
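For what it's worth, the repo's conversion notebook wraps model.generate() in a tf.function with a fixed input signature before saving, rather than saving the bare Keras model; presumably the distilled checkpoint needs the same wrapper. An untested sketch of that pattern (the max_new_tokens cap and the output name are my assumptions):

import tensorflow as tf
from transformers import TFWhisperForConditionalGeneration

model = TFWhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-small.en")

class GenerateModel(tf.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    @tf.function(
        # Whisper log-mel input: batch of 1, 80 mel bins, 3000 frames
        input_signature=[tf.TensorSpec((1, 80, 3000), tf.float32, name="input_features")],
    )
    def serving(self, input_features):
        outputs = self.model.generate(
            input_features,
            max_new_tokens=200,  # assumed cap so the decoding loop has a static bound
            return_dict_in_generate=True,
        )
        return {"sequences": outputs["sequences"]}

saved_model_dir = '/content/tf_whisper_saved'
generate_model = GenerateModel(model=model)
tf.saved_model.save(
    generate_model,
    saved_model_dir,
    signatures={"serving_default": generate_model.serving},
)

After this, the converter snippet above (ending in tflite_model = converter.convert()) at least has a single generate signature to work from.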

nyadla-sys commented 7 months ago

I have not tried converting Distilled Whisper to TFLite.

mikel-brostrom commented 7 months ago

Just setting all the paths to distil-whisper/distil-small.en, like this:

import tensorflow as tf

from datasets import load_dataset
from transformers import WhisperProcessor, WhisperFeatureExtractor, TFWhisperForConditionalGeneration, WhisperTokenizer

feature_extractor = WhisperFeatureExtractor.from_pretrained("distil-whisper/distil-small.en")
tokenizer = WhisperTokenizer.from_pretrained("distil-whisper/distil-small.en", predict_timestamps=False)
processor = WhisperProcessor(feature_extractor, tokenizer)
model = TFWhisperForConditionalGeneration.from_pretrained("distil-whisper/distil-small.en")
# Loading dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

inputs = feature_extractor(
    ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="tf"
)
input_features = inputs.input_features

# Generating Transcription
generated_ids = model.generate(input_features=input_features)
print(generated_ids)
transcription = processor.tokenizer.decode(generated_ids[0])
print(transcription)
model.save('/content/tf_whisper_saved')

surprisingly works, @nyadla-sys. We tried the same WAV file in the app with both regular Whisper small and the distilled model, and the difference in execution time is minimal: 22 vs. 21 seconds. Does anybody have a rationale for this? The distilled Whisper small TFLite model comes out at 170 MB, versus 220 MB for regular Whisper.
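One plausible factor (speculation on my part): Distil-Whisper keeps the encoder unchanged and only shrinks the decoder, so the end-to-end gain depends on how much of the TFLite runtime the decoder actually accounts for. A minimal sketch for timing the two .tflite files head-to-head (file names are hypothetical):

import time

import numpy as np
import tensorflow as tf

def benchmark(tflite_path, features, runs=5):
    # Average wall-clock seconds per invocation of a Whisper .tflite model
    interpreter = tf.lite.Interpreter(model_path=tflite_path)
    interpreter.allocate_tensors()
    input_index = interpreter.get_input_details()[0]["index"]
    interpreter.set_tensor(input_index, features)
    interpreter.invoke()  # warm-up run
    start = time.perf_counter()
    for _ in range(runs):
        interpreter.set_tensor(input_index, features)
        interpreter.invoke()
    return (time.perf_counter() - start) / runs

features = np.asarray(input_features, dtype=np.float32)  # from the script above
print("small:  %.2f s" % benchmark("whisper-small.tflite", features))
print("distil: %.2f s" % benchmark("distil-small.tflite", features))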

nyadla-sys commented 1 week ago

@mikel-brostrom Good work! Could you post the generated distilled Whisper small TFLite model via a Google Drive link? People could benefit from this model.
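For anyone picking such a model up later, a minimal end-to-end inference sketch (the file name, signature key, and tensor names are assumptions that follow the generate wrapper sketched above):

import tensorflow as tf
from datasets import load_dataset
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("distil-whisper/distil-small.en")
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
inputs = processor(
    ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="tf"
)

interpreter = tf.lite.Interpreter(model_path="distil-whisper-small.tflite")  # hypothetical file name
runner = interpreter.get_signature_runner("serving_default")
outputs = runner(input_features=inputs.input_features.numpy())

# Decode the generated token ids back to text
print(processor.tokenizer.decode(outputs["sequences"][0], skip_special_tokens=True))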