nyadla-sys / whisper.tflite

Optimized OpenAI's Whisper TFLite Port for Efficient Offline Inference on Edge Devices
MIT License

Is it possible to set input language with whisper-base.tflite with this code #29

Closed SchweitzerGAO closed 3 months ago

SchweitzerGAO commented 6 months ago

Here's the thing: I speak Chinese, but the result is in English with a similar meaning. So I am wondering if it is possible to set the input language so that the output is in the same language as the input.

nyadla-sys commented 6 months ago

You need to generate a specific model depending on the language selection.
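
As a sketch of where the language/task prompt comes from (assuming the multilingual openai/whisper-base checkpoint; the English-only *.en models have no language tokens), the (position, token_id) pairs can be looked up with WhisperProcessor.get_decoder_prompt_ids before being baked into the exported model:

from transformers import WhisperProcessor

# Assumes the multilingual checkpoint; adjust the path/ID to your local copy
processor = WhisperProcessor.from_pretrained("openai/whisper-base")

# (position, token_id) pairs forcing <|zh|><|transcribe|><|notimestamps|>
forced_decoder_ids = processor.get_decoder_prompt_ids(language="zh", task="transcribe")
print(forced_decoder_ids)  # e.g. [(1, 50260), (2, 50359), (3, 50363)]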

SchweitzerGAO commented 6 months ago

Update: The model works, but no matter how I set forced_decoder_ids, the result is in English with the task set to translate. Code:

from timeit import default_timer as timer

import numpy as np
import tensorflow as tf
import whisper
from datasets import load_dataset
from transformers import WhisperProcessor, TFWhisperForConditionalGeneration, WhisperFeatureExtractor, WhisperTokenizer

# (position, token_id) pairs: force <|ja|>, <|transcribe|>, <|notimestamps|>
forced_decoder_ids_outer = [(1, 50266), (2, 50359), (3, 50363)]

def save_tf_model(model):
    # Creating force_token_map to be used in GenerationConfig
    # force_token_map = [[50258, 50266], [50359, 50363]]
    #
    # # Creating generation_config with force_token_map
    # generation_config = GenerationConfig(force_token_map=force_token_map)

    # Creating an instance of AutoProcessor from the pretrained model
    processor = WhisperProcessor.from_pretrained("./whisper-base")

    # The TFWhisperForConditionalGeneration instance is passed in as `model`

    # Loading dataset
    ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

    # Inputs
    inputs = processor(ds[0]["audio"]["array"], return_tensors="tf")
    input_features = inputs.input_features
    model.config.forced_decoder_ids = forced_decoder_ids_outer
    # Generating Transcription
    generated_ids = model.generate(input_ids=input_features, forced_decoder_ids=forced_decoder_ids_outer)
    transcription = processor.batch_decode(generated_ids)[0]
    print(transcription)
    model.save('./content/tf_whisper_saved')

def convert_tflite(model):
    # Saving the model
    saved_model_dir = './content/tf_whisper_lite'
    generate_model = GenerateModel(model=model)
    tf.saved_model.save(generate_model, saved_model_dir,
                        signatures={"serving_default": generate_model.serving})

    # Converting to TFLite model
    tflite_model_path = './whisper-base-2.tflite'
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    # converter.target_spec.supported_ops = [
    #     tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    #     tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
    # ]
    # converter.inference_input_type = tf.int8  # or tf.uint8
    # converter.inference_output_type = tf.int8  # or tf.uint8
    # converter.experimental_enable_resource_variables = True
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()

    # Saving the TFLite model
    with open(tflite_model_path, 'wb') as f:
        f.write(tflite_model)

class GenerateModel(tf.Module):
    def __init__(self, model):
        super(GenerateModel, self).__init__()
        self.model = model
        # Token IDs: [[<|startoftranscript|>, <|ja|>], [<|transcribe|>, <|notimestamps|>]]
        self.forced_decoder_ids = [[50258, 50266], [50359, 50363]]
        # self.lang_dict = {
        #     tf.constant(50259).ref(): 50259,
        #     tf.constant(50260).ref(): 50260,
        #     tf.constant(50261).ref(): 50261,
        #     tf.constant(50262).ref(): 50262,
        #     tf.constant(50263).ref(): 50263,
        #     tf.constant(50264).ref(): 50264,
        #     tf.constant(50265).ref(): 50265,
        #     tf.constant(50266).ref(): 50266,
        #     tf.constant(50267).ref(): 50267,
        #     tf.constant(50268).ref(): 50268,
        #     tf.constant(50272).ref(): 50272,
        #     tf.constant(50274).ref(): 50274,
        #     tf.constant(50290).ref(): 50290,
        #     tf.constant(50300).ref(): 50300,
        #     tf.constant(50289).ref(): 50289,
        #     tf.constant(50275).ref(): 50275
        # }

    @tf.function(
        # shouldn't need static batch size, but throws exception without it (needs to be fixed)
        input_signature=[
            tf.TensorSpec((1, 80, 3000), tf.float32, name="input_ids"),
            # tf.TensorSpec((), tf.int32, name="lang")
        ],
    )
    def serving(self, input_ids):
        self.model.config.forced_decoder_ids = self.forced_decoder_ids
        outputs = self.model.generate(
            input_ids=input_ids,
            forced_decoder_ids=self.forced_decoder_ids,
            max_new_tokens=223
        )
        return {"sequences": outputs}

def test():
    tflite_model_path = './whisper-base-2.tflite'
    feature_extractor = WhisperFeatureExtractor.from_pretrained("../whisper-base")
    tokenizer = WhisperTokenizer.from_pretrained("../whisper-base", predict_timestamps=True)
    processor = WhisperProcessor(feature_extractor, tokenizer)
    ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

    inputs = processor(
        ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="tf"
    )
    input_features = inputs.input_features

    interpreter = tf.lite.Interpreter(tflite_model_path)

    tflite_generate = interpreter.get_signature_runner()
    # NOTE: the `lang` argument only exists if the commented-out "lang" TensorSpec in GenerateModel is enabled
    generated_ids = tflite_generate(input_ids=input_features, lang=tf.constant(50259))["sequences"]
    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(transcription)

def test_2():
    tflite_model_path = './whisper-base-2.tflite'

    # Create an interpreter to run the TFLite model
    interpreter = tf.lite.Interpreter(tflite_model_path)

    # Allocate memory for the interpreter
    interpreter.allocate_tensors()

    # Get the input and output tensors
    input_tensor_0 = interpreter.get_input_details()[0]['index']
    # input_tensor_1 = interpreter.get_input_details()[1]['index']

    output_tensor = interpreter.get_output_details()[0]['index']

    inference_start = timer()

    # Calculate the mel spectrogram of the audio file
    print(f'Calculating mel spectrogram...')
    mel_from_file = whisper.audio.log_mel_spectrogram('./2.wav')

    # Pad or trim the input data to match the expected input size
    input_data = whisper.audio.pad_or_trim(mel_from_file, whisper.audio.N_FRAMES)

    # Add a batch dimension to the input data
    input_data = np.expand_dims(input_data, 0)

    # Run the TFLite model using the interpreter
    print("Invoking interpreter ...")
    interpreter.set_tensor(input_tensor_0, input_data)
    # interpreter.set_tensor(input_tensor_1, tf.constant(50259))
    interpreter.invoke()

    # Get the output data from the interpreter
    output_data = interpreter.get_tensor(output_tensor)

    # Print the output data
    # print(output_data)

    # Create a tokenizer to convert tokens to text
    wtokenizer = whisper.tokenizer.get_tokenizer(True, language="ja")

    # convert tokens to text
    print("Converting tokens ...")
    for token in output_data:
        # Replace -100 with the end of text token
        token[token == -100] = wtokenizer.eot
        text = wtokenizer.decode(token)
        print(text)

    print("\nInference took {:.2f}s ".format(timer() - inference_start))

if __name__ == '__main__':
    model = TFWhisperForConditionalGeneration.from_pretrained("./whisper-base")
    # # processor = WhisperProcessor.from_pretrained("./whisper-base")
    # # forced_decoder_ids_outer = processor.get_decoder_prompt_ids(language="zh", task="transcribe")
    # # print(forced_decoder_ids_outer)
    # #
    # save_tf_model(model)
    convert_tflite(model)
    test_2()

Output result:

<|startoftranscript|><|ja|><|translate|><|notimestamps|> The family of the Shiasu is waiting for him.<|endoftext|>

But it is expected to be in Japanese. Any solutions?

SchweitzerGAO commented 6 months ago

Update: I discovered that when setting the language to English, the task token is correctly input as <|transcribe|>, but for other languages (e.g. Japanese) the task token becomes <|translate|>, which I think is weird because I set the task token to 50359 (transcribe) in both scenarios.
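
For reference, transformers reads forced_decoder_ids as (generation_position, token_id) pairs, with position 0 reserved for <|startoftranscript|>; the [[50258, 50266], [50359, 50363]] pairs inside GenerateModel above put token IDs in both slots, so they may never be applied during generation. A minimal check with the stock openai/whisper-base tokenizer:

from transformers import WhisperTokenizer

tok = WhisperTokenizer.from_pretrained("openai/whisper-base")

# Map the IDs used in the thread back to their special-token strings
print(tok.convert_ids_to_tokens([50258, 50266, 50359, 50363]))
# ['<|startoftranscript|>', '<|ja|>', '<|transcribe|>', '<|notimestamps|>']

# Expected (position, token_id) format for a Japanese transcription prompt
forced_decoder_ids = [(1, 50266), (2, 50359), (3, 50363)]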

nyadla-sys commented 4 months ago

@SchweitzerGAO Have you discovered how to force the task token to transcribe?

SchweitzerGAO commented 4 months ago

@nyadla-sys Unfortunately, I gave up and eventually switched to ONNX with separate encoder and decoder models, with which I can set the task and language properly.
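
As a rough sketch of what that setup looks like with onnxruntime (the file names and input/output tensor names below are placeholders and depend on how the encoder and decoder were exported): because the decoder loop is driven manually, the language and task tokens are simply the first decoder inputs.

import numpy as np
import onnxruntime as ort
from transformers import WhisperTokenizer

# Placeholder file names; use the exports produced by the linked notebook
encoder = ort.InferenceSession("whisper-base-encoder.onnx")
decoder = ort.InferenceSession("whisper-base-decoder.onnx")

tok = WhisperTokenizer.from_pretrained("openai/whisper-base")

# Seed the decoder with <|startoftranscript|><|zh|><|transcribe|><|notimestamps|>
tokens = tok.convert_tokens_to_ids(
    ["<|startoftranscript|>", "<|zh|>", "<|transcribe|>", "<|notimestamps|>"]
)

# mel: (1, 80, 3000) float32 log-mel spectrogram, e.g. from whisper.audio
mel = np.zeros((1, 80, 3000), dtype=np.float32)

# Input/output names are assumptions; check them with encoder.get_inputs()
enc_out = encoder.run(None, {"input_features": mel})[0]

# Plain greedy decoding without a KV cache, for clarity
for _ in range(224):
    logits = decoder.run(
        None,
        {
            "input_ids": np.array([tokens], dtype=np.int64),
            "encoder_hidden_states": enc_out,
        },
    )[0]
    next_id = int(logits[0, -1].argmax())
    tokens.append(next_id)
    if next_id == tok.eos_token_id:  # <|endoftext|>
        break

print(tok.decode(tokens, skip_special_tokens=True))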

nyadla-sys commented 4 months ago

@SchweitzerGAO I've noticed that running the decoder TFLite model tends to be quite time-consuming. Have there been any recent improvements in its performance? If so, could you please share a Colab notebook that demonstrates how to use and create the encoder and decoder TFLite models? Thank you.

SchweitzerGAO commented 4 months ago

@nyadla-sys Sorry, I haven't tried an encoder-decoder separated model with TFLite; I just tried ONNX with an off-the-shelf notebook here. They also provide a notebook to generate an encoder-decoder separated model with TFLite here. It should be convenient to refer to these notebooks, thanks. BTW, it is also time-consuming to run ONNX models on mobile devices, and I haven't yet figured out possible approaches to improving this.