Coder-HuangBH opened 6 months ago
Does the APK work properly, or does the same issue occur with the APK as well?
I got pretty much the same result. Your APK and the .tflite models work locally (inference in Python works well and the transcription is correct). @vilassn, would you share the Python package configuration you used for the APK (Python 3.8, transformers, tensorflow, numpy, etc.)?
Thanks!
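To make that configuration easy to share, the versions can be printed from the same Python environment used for the export. This is a minimal stdlib-only sketch (it needs Python 3.8+ for `importlib.metadata`); `package_versions` is a hypothetical helper name and the package list is an assumption to adjust:

```python
import platform
from importlib import metadata

# Packages relevant to the export script in this thread (assumed list; edit as needed).
PACKAGES = ("transformers", "tensorflow", "numpy", "datasets")


def package_versions(packages=PACKAGES):
    """Return {name: version} for the running interpreter, marking missing packages."""
    versions = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


if __name__ == "__main__":
    # Prints pip-style pins, e.g. "numpy==<version>", ready to paste into a reply.
    for name, version in package_versions().items():
        print(f"{name}=={version}")
```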
2024-04-29 09:48:50.970 24207-24753 Whisper com.whispertflite E Error... java.lang.IllegalArgumentException: Internal error: Failed to run on the given Interpreter: tensorflow/lite/kernels/reduce.cc:390 std::apply(optimized_ops::Mean<T, U>, args) was not true.
tensorflow/lite/kernels/reduce.cc:390 std::apply(optimized_ops::Mean<T, U>, args) was not true.
tensorflow/lite/kernels/reduce.cc:390 std::apply(optimized_ops::Mean<T, U>, args) was not true.
tensorflow/lite/kernels/reduce.cc:390 std::apply(optimized_ops::Mean<T, U>, args) was not true.
tensorflow/lite/kernels/reduce.cc:390 std::apply(optimized_ops::Mean<T,
    at org.tensorflow.lite.NativeInterpreterWrapper.run(Native Method)
    at org.tensorflow.lite.NativeInterpreterWrapper.run(NativeInterpreterWrapper.java:247)
    at org.tensorflow.lite.InterpreterImpl.runForMultipleInputsOutputs(InterpreterImpl.java:107)
    at org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:80)
    at org.tensorflow.lite.InterpreterImpl.run(InterpreterImpl.java:100)
    at org.tensorflow.lite.Interpreter.run(Interpreter.java:80)
    at com.whispertflite.engine.WhisperEngine.runInference(WhisperEngine.java:147)
    at com.whispertflite.engine.WhisperEngine.transcribeFile(WhisperEngine.java:74)
    at com.whispertflite.asr.Whisper.threadFunction(Whisper.java:129)
    at com.whispertflite.asr.Whisper.lambda$start$0$com-whispertflite-asr-Whisper(Whisper.java:76)
    at com.whispertflite.asr.Whisper$$ExternalSyntheticLambda0.run(Unknown Source:2)
    at java.lang.Thread.run(Thread.java:930)
The only difference between the successful and the failing run is the .tflite file. These are the tensor parameter dumps for each:

Success:
Input Tensor Dump ===>
2024-04-29 09:50:13.724 24920-25027 WhisperEngineJava com.whispertflite D shape.length: 3
2024-04-29 09:50:13.725 24920-25027 WhisperEngineJava com.whispertflite D shape[0]: 1
2024-04-29 09:50:13.725 24920-25027 WhisperEngineJava com.whispertflite D shape[1]: 80
2024-04-29 09:50:13.725 24920-25027 WhisperEngineJava com.whispertflite D shape[2]: 3000
2024-04-29 09:50:13.725 24920-25027 WhisperEngineJava com.whispertflite D dataType: FLOAT32
2024-04-29 09:50:13.726 24920-25027 WhisperEngineJava com.whispertflite D name: serving_default_input_ids:0
2024-04-29 09:50:13.726 24920-25027 WhisperEngineJava com.whispertflite D numBytes: 960000
2024-04-29 09:50:13.726 24920-25027 WhisperEngineJava com.whispertflite D index: 0
2024-04-29 09:50:13.726 24920-25027 WhisperEngineJava com.whispertflite D numDimensions: 3
2024-04-29 09:50:13.727 24920-25027 WhisperEngineJava com.whispertflite D numElements: 240000
2024-04-29 09:50:13.727 24920-25027 WhisperEngineJava com.whispertflite D shapeSignature.length: 3
2024-04-29 09:50:13.727 24920-25027 WhisperEngineJava com.whispertflite D quantizationParams.getScale: 0.0
2024-04-29 09:50:13.727 24920-25027 WhisperEngineJava com.whispertflite D quantizationParams.getZeroPoint: 0
2024-04-29 09:50:13.727 24920-25027 WhisperEngineJava com.whispertflite D ==================================================================
2024-04-29 09:50:13.728 24920-25027 WhisperEngineJava com.whispertflite D Output Tensor Dump ===>
2024-04-29 09:50:13.728 24920-25027 WhisperEngineJava com.whispertflite D shape.length: 2
2024-04-29 09:50:13.728 24920-25027 WhisperEngineJava com.whispertflite D shape[0]: 1
2024-04-29 09:50:13.728 24920-25027 WhisperEngineJava com.whispertflite D shape[1]: 448
2024-04-29 09:50:13.729 24920-25027 WhisperEngineJava com.whispertflite D dataType: INT32
2024-04-29 09:50:13.729 24920-25027 WhisperEngineJava com.whispertflite D name: StatefulPartitionedCall:0
2024-04-29 09:50:13.729 24920-25027 WhisperEngineJava com.whispertflite D numBytes: 1792
2024-04-29 09:50:13.729 24920-25027 WhisperEngineJava com.whispertflite D index: 1047
2024-04-29 09:50:13.729 24920-25027 WhisperEngineJava com.whispertflite D numDimensions: 2
2024-04-29 09:50:13.729 24920-25027 WhisperEngineJava com.whispertflite D numElements: 448
2024-04-29 09:50:13.730 24920-25027 WhisperEngineJava com.whispertflite D shapeSignature.length: 2
2024-04-29 09:50:13.730 24920-25027 WhisperEngineJava com.whispertflite D quantizationParams.getScale: 0.0
2024-04-29 09:50:13.730 24920-25027 WhisperEngineJava com.whispertflite D quantizationParams.getZeroPoint: 0
2024-04-29 09:50:13.730 24920-25027 WhisperEngineJava com.whispertflite D ==================================================================
Failed:
Input Tensor Dump ===>
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D shape.length: 3
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D shape[0]: 1
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D shape[1]: 80
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D shape[2]: 3000
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D dataType: FLOAT32
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D name: serving_default_input_ids:0
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D numBytes: 960000
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D index: 0
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D numDimensions: 3
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D numElements: 240000
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D shapeSignature.length: 3
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D quantizationParams.getScale: 0.0
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D quantizationParams.getZeroPoint: 0
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D ==================================================================
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D Output Tensor Dump ===>
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D shape.length: 2
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D shape[0]: 1
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D shape[1]: 451
2024-04-29 09:48:45.685 24207-24753 WhisperEngineJava com.whispertflite D dataType: INT32
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D name: StatefulPartitionedCall:0
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D numBytes: 1804
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D index: 559
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D numDimensions: 2
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D numElements: 451
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D shapeSignature.length: 2
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D quantizationParams.getScale: 0.0
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D quantizationParams.getZeroPoint: 0
2024-04-29 09:48:45.686 24207-24753 WhisperEngineJava com.whispertflite D ==================================================================
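In these two dumps the input tensor is identical ([1, 80, 3000] FLOAT32, `serving_default_input_ids:0`), while the output differs: [1, 448] INT32 in the working model versus [1, 451] INT32 in the failing one (numBytes 1792 = 448 × 4 and 1804 = 451 × 4). The tensor index (1047 vs 559) also differs, but that is most likely just the tensor's position inside each graph. The same comparison can be made on the desktop without rebuilding the APK. The sketch below assumes the TensorFlow pip package and uses `tf.lite.Interpreter.get_input_details()` / `get_output_details()`; the helper names `tensor_summary`, `diff_summaries`, `load_summaries` and `compare` are hypothetical:

```python
def tensor_summary(details):
    """Keep the compatibility-relevant fields from Interpreter tensor-detail dicts."""
    return [
        {
            "name": d["name"],
            "shape": tuple(int(x) for x in d["shape"]),
            # Interpreter details hold a NumPy type (e.g. numpy.int32); fall back to str().
            "dtype": getattr(d["dtype"], "__name__", str(d["dtype"])),
        }
        for d in details
    ]


def diff_summaries(a, b):
    """Describe differences between two summaries; the tensor index is ignored on purpose."""
    diffs = []
    if len(a) != len(b):
        diffs.append(f"tensor count: {len(a)} vs {len(b)}")
    for i, (ta, tb) in enumerate(zip(a, b)):
        for key in ("name", "shape", "dtype"):
            if ta[key] != tb[key]:
                diffs.append(f"tensor {i} {key}: {ta[key]} vs {tb[key]}")
    return diffs


def load_summaries(model_path):
    """Return (inputs, outputs) summaries for a .tflite file."""
    import tensorflow as tf  # lazy import so the pure helpers above work without TF

    interpreter = tf.lite.Interpreter(model_path=model_path)
    return (
        tensor_summary(interpreter.get_input_details()),
        tensor_summary(interpreter.get_output_details()),
    )


def compare(good_path, bad_path):
    """Print input/output differences between a working and a failing .tflite file."""
    good_in, good_out = load_summaries(good_path)
    bad_in, bad_out = load_summaries(bad_path)
    print("inputs:", diff_summaries(good_in, bad_in) or "identical")
    print("outputs:", diff_summaries(good_out, bad_out) or "identical")
```

Calling `compare(<working model path>, <failing model path>)` should then report only the output shape if the files differ the way the logcat dumps suggest.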
The script that generates the failing .tflite file is as follows:
import tensorflow as tf
from datasets import load_dataset
from transformers import WhisperProcessor, WhisperFeatureExtractor, TFWhisperForConditionalGeneration, WhisperTokenizer

whisperPath = "openai/whisper-tiny.en"
saved_model_dir = 'path/to/tf_whisper_saved'
tflite_model_path = 'path/to/whisper111.tflite'

feature_extractor = WhisperFeatureExtractor.from_pretrained(whisperPath)
tokenizer = WhisperTokenizer.from_pretrained(whisperPath, predict_timestamps=True)
processor = WhisperProcessor(feature_extractor, tokenizer)
model = TFWhisperForConditionalGeneration.from_pretrained(whisperPath, from_pt=True)

# Loading dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True)
inputs = feature_extractor(
    ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="tf"
)
input_features = inputs.input_features

# Generating Transcription
generated_ids = model.generate(input_features=input_features)
print(generated_ids)
transcription = processor.tokenizer.decode(generated_ids[0])
print(transcription)
model.save(saved_model_dir)

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS,  # enable TensorFlow ops.
]
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save the model
with open(tflite_model_path, 'wb') as f:
    f.write(tflite_model)

class GenerateModel(tf.Module):
    def __init__(self, model):
        super(GenerateModel, self).__init__()
        self.model = model

    @tf.function(
        # shouldn't need static batch size, but throws exception without it (needs to be fixed)
        # shape/dtype/name as reported in the input tensor dump (serving_default_input_ids:0)
        input_signature=[
            tf.TensorSpec((1, 80, 3000), tf.float32, name="input_ids"),
        ],
    )
    def serving(self, input_features):
        outputs = self.model.generate(
            input_features,
            max_new_tokens=450,  # change as needed
            return_dict_in_generate=True,
        )
        return {"sequences": outputs["sequences"]}

saved_model_dir = '/content/tf_whisper_saved'
# Only the last tflite_model_path assignment takes effect.
tflite_model_path = 'whisper-tiny.en.tflite'
tflite_model_path = 'path/to/whisper222.tflite'
tflite_model_path = 'path/to/whisper_vi222.tflite'

generate_model = GenerateModel(model=model)
tf.saved_model.save(generate_model, saved_model_dir, signatures={"serving_default": generate_model.serving})

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS,  # enable TensorFlow ops.
]
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save the model
with open(tflite_model_path, 'wb') as f:
    f.write(tflite_model)
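Since Python inference was reported to work for the shared models, one way to check whether the failing export also fails off-device is to run it once with `tf.lite.Interpreter` right after conversion. This is a minimal sketch under a few assumptions: the TensorFlow pip package is installed (its Python interpreter can run SELECT_TF_OPS), and the Whisper feature-extractor defaults (16 kHz audio padded to 30 s, hop length 160, 80 mel bins) are what produce the [1, 80, 3000] input seen in the dumps. `mel_frames`, `EXPECTED_INPUT_SHAPE` and `run_tflite` are hypothetical names:

```python
# Assumed Whisper feature-extractor defaults for whisper-tiny.en:
# 16 kHz audio padded to 30 s, hop length 160 samples, 80 mel bins.
def mel_frames(seconds=30, sampling_rate=16000, hop_length=160):
    """Number of mel frames for a padded clip: total samples / hop length."""
    return seconds * sampling_rate // hop_length


N_MELS = 80
EXPECTED_INPUT_SHAPE = (1, N_MELS, mel_frames())  # (batch, mel bins, frames)


def run_tflite(model_path, features=None):
    """Run a .tflite Whisper export once on the desktop and return its first output."""
    import numpy as np  # lazy imports: only needed when actually running a model
    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    input_detail = interpreter.get_input_details()[0]
    output_detail = interpreter.get_output_details()[0]

    if features is None:
        # Placeholder all-zero input; real log-mel features are a more meaningful test.
        features = np.zeros(EXPECTED_INPUT_SHAPE, dtype=np.float32)
    features = np.asarray(features, dtype=np.float32)
    if tuple(features.shape) != tuple(input_detail["shape"]):
        raise ValueError(f"expected input shape {tuple(input_detail['shape'])}, got {features.shape}")

    interpreter.set_tensor(input_detail["index"], features)
    interpreter.invoke()
    return interpreter.get_tensor(output_detail["index"])
```

With the script's own variables, `ids = run_tflite(tflite_model_path, input_features.numpy())` followed by `print(processor.tokenizer.decode(ids[0]))` would show whether the model runs in Python. A successful desktop run does not guarantee the Android runtime behaves the same, since the TFLite version bundled with the APK may differ.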