vilassn / whisper_android

Offline Speech Recognition with OpenAI Whisper and TensorFlow Lite for Android
MIT License

IllegalArgumentException: Internal error: Failed to run on the given Interpreter #12

Open Coder-HuangBH opened 4 months ago

Coder-HuangBH commented 4 months ago

2024-04-29 09:48:50.970 24207-24753 Whisper com.whispertflite E Error...
java.lang.IllegalArgumentException: Internal error: Failed to run on the given Interpreter: tensorflow/lite/kernels/reduce.cc:390 std::apply(optimized_ops::Mean<T, U>, args) was not true.
tensorflow/lite/kernels/reduce.cc:390 std::apply(optimized_ops::Mean<T, U>, args) was not true.
tensorflow/lite/kernels/reduce.cc:390 std::apply(optimized_ops::Mean<T, U>, args) was not true.
tensorflow/lite/kernels/reduce.cc:390 std::apply(optimized_ops::Mean<T, U>, args) was not true.
tensorflow/lite/kernels/reduce.cc:390 std::apply(optimized_ops::Mean<T,
    at org.tensorflow.lite.NativeInterpreterWrapper.run(Native Method)
    at org.tensorflow.lite.NativeInterpreterWrapper.run(NativeInterpreterWrapper.java:247)
    at org.tensorflow.lite.InterpreterImpl.runForMultipleInputsOutputs(InterpreterImpl.java:107)
    at org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:80)
    at org.tensorflow.lite.InterpreterImpl.run(InterpreterImpl.java:100)
    at org.tensorflow.lite.Interpreter.run(Interpreter.java:80)
    at com.whispertflite.engine.WhisperEngine.runInference(WhisperEngine.java:147)
    at com.whispertflite.engine.WhisperEngine.transcribeFile(WhisperEngine.java:74)
    at com.whispertflite.asr.Whisper.threadFunction(Whisper.java:129)
    at com.whispertflite.asr.Whisper.lambda$start$0$com-whispertflite-asr-Whisper(Whisper.java:76)
    at com.whispertflite.asr.Whisper$$ExternalSyntheticLambda0.run(Unknown Source:2)
    at java.lang.Thread.run(Thread.java:930)

The only difference between the successful run and the failing run is the tflite file. Here are the tensor parameters printed for each.

Success:

Input Tensor Dump ===>
2024-04-29 09:50:13.724 24920-25027 WhisperEngineJava com.whispertflite D shape.length: 3
2024-04-29 09:50:13.725 24920-25027 WhisperEngineJava com.whispertflite D shape[0]: 1
2024-04-29 09:50:13.725 24920-25027 WhisperEngineJava com.whispertflite D shape[1]: 80
2024-04-29 09:50:13.725 24920-25027 WhisperEngineJava com.whispertflite D shape[2]: 3000
2024-04-29 09:50:13.725 24920-25027 WhisperEngineJava com.whispertflite D dataType: FLOAT32
2024-04-29 09:50:13.726 24920-25027 WhisperEngineJava com.whispertflite D name: serving_default_input_ids:0
2024-04-29 09:50:13.726 24920-25027 WhisperEngineJava com.whispertflite D numBytes: 960000
2024-04-29 09:50:13.726 24920-25027 WhisperEngineJava com.whispertflite D index: 0
2024-04-29 09:50:13.726 24920-25027 WhisperEngineJava com.whispertflite D numDimensions: 3
2024-04-29 09:50:13.727 24920-25027 WhisperEngineJava com.whispertflite D numElements: 240000
2024-04-29 09:50:13.727 24920-25027 WhisperEngineJava com.whispertflite D shapeSignature.length: 3
2024-04-29 09:50:13.727 24920-25027 WhisperEngineJava com.whispertflite D quantizationParams.getScale: 0.0
2024-04-29 09:50:13.727 24920-25027 WhisperEngineJava com.whispertflite D quantizationParams.getZeroPoint: 0
2024-04-29 09:50:13.727 24920-25027 WhisperEngineJava com.whispertflite D ==================================================================
2024-04-29 09:50:13.728 24920-25027 WhisperEngineJava com.whispertflite D Output Tensor Dump ===>
2024-04-29 09:50:13.728 24920-25027 WhisperEngineJava com.whispertflite D shape.length: 2
2024-04-29 09:50:13.728 24920-25027 WhisperEngineJava com.whispertflite D shape[0]: 1
2024-04-29 09:50:13.728 24920-25027 WhisperEngineJava com.whispertflite D shape[1]: 448
2024-04-29 09:50:13.729 24920-25027 WhisperEngineJava com.whispertflite D dataType: INT32
2024-04-29 09:50:13.729 24920-25027 WhisperEngineJava com.whispertflite D name: StatefulPartitionedCall:0
2024-04-29 09:50:13.729 24920-25027 WhisperEngineJava com.whispertflite D numBytes: 1792
2024-04-29 09:50:13.729 24920-25027 WhisperEngineJava com.whispertflite D index: 1047
2024-04-29 09:50:13.729 24920-25027 WhisperEngineJava com.whispertflite D numDimensions: 2
2024-04-29 09:50:13.729 24920-25027 WhisperEngineJava com.whispertflite D numElements: 448
2024-04-29 09:50:13.730 24920-25027 WhisperEngineJava com.whispertflite D shapeSignature.length: 2
2024-04-29 09:50:13.730 24920-25027 WhisperEngineJava com.whispertflite D quantizationParams.getScale: 0.0
2024-04-29 09:50:13.730 24920-25027 WhisperEngineJava com.whispertflite D quantizationParams.getZeroPoint: 0
2024-04-29 09:50:13.730 24920-25027 WhisperEngineJava com.whispertflite D ==================================================================

Failed:

Input Tensor Dump ===>
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D shape.length: 3
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D shape[0]: 1
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D shape[1]: 80
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D shape[2]: 3000
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D dataType: FLOAT32
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D name: serving_default_input_ids:0
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D numBytes: 960000
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D index: 0
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D numDimensions: 3
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D numElements: 240000
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D shapeSignature.length: 3
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D quantizationParams.getScale: 0.0
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D quantizationParams.getZeroPoint: 0
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D ==================================================================
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D Output Tensor Dump ===>
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D shape.length: 2
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D shape[0]: 1
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D shape[1]: 451
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D dataType: INT32
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D name: StatefulPartitionedCall:0
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D numBytes: 1804
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D index: 559
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D numDimensions: 2
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D numElements: 451
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D shapeSignature.length: 2
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D quantizationParams.getScale: 0.0
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D quantizationParams.getZeroPoint: 0
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D ==================================================================
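To make the comparison between the two dumps mechanical rather than visual, here is a minimal sketch that parses logcat lines like the ones above and reports which fields differ. The helper names (`parse_dump`, `diff_dumps`) and the regular expression are illustrative assumptions, not part of the app; the sample lines are copied from the output-tensor sections of the two dumps.

```python
import re

# Excerpts copied from the output-tensor sections of the two dumps above.
SUCCESS_LOG = """\
2024-04-29 09:50:13.728 24920-25027 WhisperEngineJava com.whispertflite D Output Tensor Dump ===>
2024-04-29 09:50:13.728 24920-25027 WhisperEngineJava com.whispertflite D shape[1]: 448
2024-04-29 09:50:13.729 24920-25027 WhisperEngineJava com.whispertflite D dataType: INT32
2024-04-29 09:50:13.729 24920-25027 WhisperEngineJava com.whispertflite D numBytes: 1792
2024-04-29 09:50:13.729 24920-25027 WhisperEngineJava com.whispertflite D index: 1047
2024-04-29 09:50:13.729 24920-25027 WhisperEngineJava com.whispertflite D numElements: 448
"""

FAILED_LOG = """\
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D Output Tensor Dump ===>
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D shape[1]: 451
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D dataType: INT32
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D numBytes: 1804
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D index: 559
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D numElements: 451
"""

# Hypothetical pattern: the "key: value" payload after the logcat "D" priority.
_FIELD = re.compile(r"\bD\s+([\w.\[\]]+):\s+(\S+)\s*$")
_HEADER = re.compile(r"(Input|Output) Tensor Dump")


def parse_dump(text):
    """Group "key: value" fields under the Input/Output section they follow."""
    sections, current = {}, None
    for line in text.splitlines():
        header = _HEADER.search(line)
        if header:
            current = sections.setdefault(header.group(1), {})
            continue
        field = _FIELD.search(line)
        if field and current is not None:
            current[field.group(1)] = field.group(2)
    return sections


def diff_dumps(ok, bad):
    """Return {"Section.key": (ok_value, bad_value)} for fields that differ."""
    differences = {}
    for section in sorted(set(ok) | set(bad)):
        a, b = ok.get(section, {}), bad.get(section, {})
        for key in sorted(set(a) | set(b)):
            if a.get(key) != b.get(key):
                differences[f"{section}.{key}"] = (a.get(key), b.get(key))
    return differences


differences = diff_dumps(parse_dump(SUCCESS_LOG), parse_dump(FAILED_LOG))
for name, (ok_value, bad_value) in differences.items():
    print(f"{name}: success={ok_value} failed={bad_value}")
```

On these excerpts the differing fields are index, numBytes, numElements and shape[1]. Since numBytes is numElements times 4 bytes for INT32, the substantive difference in the tensor metadata is the output length: 448 in the working model versus 451 in the failing one.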

The script for generating the failed tflite file is as follows:

import tensorflow as tf
from datasets import load_dataset
from transformers import WhisperProcessor, WhisperFeatureExtractor, TFWhisperForConditionalGeneration, WhisperTokenizer

whisperPath = "openai/whisper-tiny.en"
saved_model_dir = 'path/to/tf_whisper_saved'

tflite_model_path = 'path/to/whisper111.tflite'

feature_extractor = WhisperFeatureExtractor.from_pretrained(whisperPath)
tokenizer = WhisperTokenizer.from_pretrained(whisperPath, predict_timestamps=True)
processor = WhisperProcessor(feature_extractor, tokenizer)
model = TFWhisperForConditionalGeneration.from_pretrained(whisperPath, from_pt=True)

# Loading dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True)

inputs = feature_extractor(
    ds[0]["audio"]["array"],
    sampling_rate=ds[0]["audio"]["sampling_rate"],
    return_tensors="tf"
)
input_features = inputs.input_features

# Generating Transcription
generated_ids = model.generate(input_features=input_features)
print(generated_ids)
transcription = processor.tokenizer.decode(generated_ids[0])
print(transcription)
model.save(saved_model_dir)

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save the model
with open(tflite_model_path, 'wb') as f:
    f.write(tflite_model)

class GenerateModel(tf.Module):
    def __init__(self, model):
        super(GenerateModel, self).__init__()
        self.model = model

    @tf.function(
        # shouldn't need static batch size, but throws exception without it (needs to be fixed)
        input_signature=[
            tf.TensorSpec((1, 80, 3000), tf.float32, name="input_ids"),
        ],
    )
    def serving(self, input_features):
        outputs = self.model.generate(
            input_features,
            max_new_tokens=450,  # change as needed
            return_dict_in_generate=True,
        )
        return {"sequences": outputs["sequences"]}

saved_model_dir = '/content/tf_whisper_saved'

tflite_model_path = 'whisper-tiny.en.tflite'

tflite_model_path = 'path/to/whisper222.tflite'

tflite_model_path = 'path/to/whisper_vi222.tflite'

generate_model = GenerateModel(model=model)
tf.saved_model.save(generate_model, saved_model_dir, signatures={"serving_default": generate_model.serving})

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save the model
with open(tflite_model_path, 'wb') as f:
    f.write(tflite_model)
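Before loading the generated file on Android, it may help to check it on the desktop with the TensorFlow Lite Python interpreter. The sketch below is one way to do that, assuming the full `tensorflow` pip package (which bundles the ops needed for `SELECT_TF_OPS` models) and `numpy` are installed. The helper names `summarize`, `inspect_tflite` and `run_on_zeros` are made up for illustration, and an all-zeros input only exercises the graph; it is not a meaningful audio sample.

```python
def summarize(details):
    """Format entries shaped like Interpreter.get_input_details() output."""
    lines = []
    for d in details:
        dtype = getattr(d["dtype"], "__name__", str(d["dtype"]))
        lines.append(f"{d['name']}: shape={list(d['shape'])} dtype={dtype}")
    return lines


def inspect_tflite(model_path):
    """Return (input summary, output summary) for a .tflite file."""
    import tensorflow as tf  # lazy import so summarize() works without TensorFlow

    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    return (
        summarize(interpreter.get_input_details()),
        summarize(interpreter.get_output_details()),
    )


def run_on_zeros(model_path):
    """Invoke the model once on an all-zeros input and return the first output."""
    import numpy as np
    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    inp = interpreter.get_input_details()[0]
    interpreter.set_tensor(inp["index"], np.zeros(inp["shape"], dtype=inp["dtype"]))
    interpreter.invoke()  # raises if a kernel fails on this runtime
    out = interpreter.get_output_details()[0]
    return interpreter.get_tensor(out["index"])


# Example usage, with the path used in the script above:
# print(inspect_tflite("path/to/whisper111.tflite"))
# print(run_on_zeros("path/to/whisper111.tflite").shape)
```

If `run_on_zeros` succeeds on the desktop but the same file fails in the app, that points toward a difference between the desktop and Android runtimes rather than a broken conversion, though it does not prove it.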

vilassn commented 4 months ago

Does the APK work properly, or do you see the same issue with the APK as well?

robre22 commented 2 months ago

I got pretty much the same result. Your APK and the tflite models work locally (inference in Python works well, and the resulting transcription is correct). @vilassn, would you share the Python package configuration that you used for the APK (Python 3.8, transformers, tensorflow, numpy, etc.)?

Thanks!
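For anyone comparing environments in this thread, a small sketch like the following prints the versions that matter for the conversion. It uses only the standard library (`importlib.metadata`, available from Python 3.8); the default package list is an assumption based on the imports in the script above, and the function names are illustrative.

```python
import platform
from importlib import metadata


def package_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def environment_report(names=("tensorflow", "transformers", "numpy", "datasets")):
    """Build a short, pasteable report of the Python and package versions."""
    lines = [f"python=={platform.python_version()}"]
    for name, version in package_versions(names).items():
        lines.append(f"{name}=={version}" if version else f"{name}: not installed")
    return "\n".join(lines)


print(environment_report())
```

Pasting this output next to a failing or working tflite file makes it easier to see whether a version mismatch is involved.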