huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

TF Lite model created from TFWhisperForConditionalGeneration.from_pretrained crashes #32125

Open dsame opened 1 month ago

dsame commented 1 month ago

System Info

The main branch still does not include the fix from https://github.com/huggingface/transformers/issues/19691#issuecomment-1791869884, namely:

    #new_scores = tf.ones_like(scores, dtype=scores.dtype) * -float("inf")
    new_scores = tf.ones_like(scores, dtype=scores.dtype) * -float(1)

This causes the TFLite Whisper model to crash when used as:

    interpreter = tf.lite.Interpreter(tflite_model_path)
    tflite_generate = interpreter.get_signature_runner()
    output = tflite_generate(input_features=input_features)
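
(For context, the patched line replaces non-finite -inf scores with a finite constant. The snippet below is a minimal plain-TensorFlow sketch of my own, not the transformers code path; it only illustrates that -inf masks stay non-finite through reductions, which is consistent with the Mean-kernel failure in the traceback further down:)

    import tensorflow as tf

    scores = tf.random.normal((1, 51865))  # fake logits row, Whisper-vocab-sized

    inf_mask = tf.ones_like(scores, dtype=scores.dtype) * -float("inf")  # un-patched line
    fin_mask = tf.ones_like(scores, dtype=scores.dtype) * -float(1)      # patched line

    print(tf.reduce_mean(inf_mask).numpy())  # -inf, stays non-finite downstream
    print(tf.reduce_mean(fin_mask).numpy())  # -1.0, finite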

Who can help?

@ArthurZucker

Information

Tasks

Reproduction

import tensorflow as tf
from transformers import TFWhisperForConditionalGeneration, WhisperProcessor

class GenerateModel(tf.Module):
    def __init__(self, model, forced_decoder_ids):
        super(GenerateModel, self).__init__()
        self.model = model
        self.forced_decoder_ids = forced_decoder_ids

    @tf.function(
        # shouldn't need static batch size, but throws exception without it (needs to be fixed)
        input_signature=[
            tf.TensorSpec((1, 80, 3000), tf.float32, name="input_features"),
        ],
    )
    def serving(self, input_features):
        outputs = self.model.generate(
            input_features,
            forced_decoder_ids=self.forced_decoder_ids,
            #max_new_tokens=223,  # change as needed
            return_dict_in_generate=True,
        )
        return {"sequences": outputs["sequences"]}

# Example values; placeholders for illustration, substitute your own
model_name = "openai/whisper-tiny"
saved_model_dir = "whisper_saved_model"
tflite_model_path = "whisper.tflite"

# The forced decoder ids (language/task prompt tokens) come from the processor
processor = WhisperProcessor.from_pretrained(model_name)
forced_decoder_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")

model = TFWhisperForConditionalGeneration.from_pretrained(model_name)
generate_model = GenerateModel(model=model, forced_decoder_ids=forced_decoder_ids)
tf.saved_model.save(generate_model, saved_model_dir, signatures={"serving_default": generate_model.serving})

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
        tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]

converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save the model
with open(tflite_model_path, 'wb') as f:
        f.write(tflite_model)

....
    interpreter = tf.lite.Interpreter(tflite_model_path)
    tflite_generate = interpreter.get_signature_runner()
    output = tflite_generate(input_features=input_features)

crashes at the tflite_generate call.
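
(For reference, input_features here is the (1, 80, 3000) log-mel tensor expected by the serving signature. A minimal sketch of producing it with WhisperProcessor; the silence array is a placeholder, any 16 kHz mono audio works:)

    import numpy as np
    from transformers import WhisperProcessor

    processor = WhisperProcessor.from_pretrained(model_name)  # same checkpoint as above
    audio = np.zeros(16000 * 5, dtype=np.float32)  # placeholder: 5 s of silence
    input_features = processor(
        audio, sampling_rate=16000, return_tensors="tf"
    ).input_features  # shape (1, 80, 3000)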

Applying the patch https://github.com/huggingface/transformers/issues/19691#issuecomment-1791869884 solves the problem.

Expected behavior

The reproduction code above should run without the patch.

dsame commented 1 month ago

The colab notebook: https://colab.research.google.com/drive/1CIDh8wqZOS7ifkUQkEIxM5osEXAeYMof

amyeroberts commented 1 month ago

cc @Rocketknight1

dsame commented 1 month ago

The exact error:

2024-07-28 09:56:51.045333: I tensorflow/compiler/mlir/lite/flatbuffer_export.cc:3064] Estimated count of arithmetic ops: 76.370 G  ops, equivalently 38.185 G  MACs
Traceback (most recent call last):
  File "/home/dsa/whishper2tflite/whisper2tflite.py", line 96, in <module>
    test()
  File "/home/dsa/whishper2tflite/test.py", line 49, in test
    output = tflite_generate(input_features=input_features)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dsa/whishper2tflite/.venv/lib/python3.12/site-packages/tensorflow/lite/python/interpreter.py", line 249, in __call__
    self._interpreter_wrapper.Invoke(self._subgraph_index)
RuntimeError: tensorflow/lite/kernels/reduce.cc:445 std::apply(optimized_ops::Mean<T, U>, args) was not true. [the same reduce.cc:445 message repeated many times] gather index out of bounds. Node number 34 (GATHER) failed to invoke. Node number 618 (WHILE) failed to invoke.

python3 --version
Python 3.12.3

pip freeze

absl-py==2.1.0
astunparse==1.6.3
certifi==2024.7.4
cffi==1.16.0
charset-normalizer==3.3.2
coloredlogs==15.0.1
filelock==3.15.4
flatbuffers==24.3.25
fsspec==2024.6.1
gast==0.6.0
google-pasta==0.2.0
grpcio==1.65.1
h5py==3.11.0
huggingface-hub==0.24.2
humanfriendly==10.0
idna==3.7
Jinja2==3.1.4
keras==3.4.1
libclang==18.1.1
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
mdurl==0.1.2
ml-dtypes==0.3.2
mpmath==1.3.0
namex==0.0.8
networkx==3.3
numpy==1.26.4
onnx==1.16.1
onnxruntime==1.18.1
opt-einsum==3.3.0
optree==0.12.1
packaging==24.1
protobuf==4.25.4
pycparser==2.22
Pygments==2.18.0
PySoundFile==0.9.0.post1
PyYAML==6.0.1
regex==2024.7.24
requests==2.32.3
rich==13.7.1
safetensors==0.4.3
setuptools==71.1.0
six==1.16.0
sympy==1.13.1
tensorboard==2.16.2
tensorboard-data-server==0.7.2
tensorflow==2.16.1
tensorflow-io==0.37.1
tensorflow-io-gcs-filesystem==0.37.1
termcolor==2.4.0
tf_keras==2.16.0
tokenizers==0.19.1
torch==2.4.0
torchaudio==2.4.0
tqdm==4.66.4
transformers==4.43.3
typing_extensions==4.12.2
urllib3==2.2.2
Werkzeug==3.0.3
wheel==0.43.0
wrapt==1.16.0
Rocketknight1 commented 1 month ago

@dsame can you test the PR at #32301 and let me know if it fixes the issue for you?

You can install from the PR branch with pip install --upgrade git+https://github.com/huggingface/transformers.git@fix_whisper_tflite_export

github-actions[bot] commented 2 weeks ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.