huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0
131.97k stars 26.29k forks

Prompting TFWhisperForConditionalGeneration leads to runtime crashes #32880

Open Manuel030 opened 3 weeks ago

Manuel030 commented 3 weeks ago

Who can help?

@sanchit-gandhi @gante

Reproduction

import librosa
from transformers import WhisperProcessor, WhisperFeatureExtractor, TFWhisperForConditionalGeneration

model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")

# get some dummy data
audio, sr = librosa.load("audio/samples_jfk.wav", sr=16000, mono=True)
inputs = feature_extractor(audio, sampling_rate=sr, return_tensors="tf")
input_features = inputs.input_features
prompts = processor.get_prompt_ids("some random words", return_tensors="tf")

outputs = model.generate(
    input_features,
    return_dict_in_generate=True,
    prompt_ids = prompts,
)

Expected behavior

Running the code above leads to:

AttributeError: EagerTensor object has no attribute 'tolist'. If you are looking for numpy-related methods, please run the following: from tensorflow.python.ops.numpy_ops import np_config np_config.enable_numpy_behavior()

It seems that prompt_ids cannot be used in graph mode (which I need). Also, a workaround using decoder_input_ids leads to other issues.

amyeroberts commented 3 weeks ago

cc @gante @Rocketknight1

Rocketknight1 commented 3 weeks ago

@Manuel030 have you tried return_tensors="np" instead?

Manuel030 commented 3 weeks ago

Thanks @Rocketknight1, I don't have NumPy available when executing in graph mode. I would expect the generate pass to be compatible with TensorFlow's graph execution mode.

Manuel030 commented 3 weeks ago

Also, the docstring states that prompt_ids should be a tf.Tensor.

Rocketknight1 commented 3 weeks ago

Got it - can you paste the entire traceback so I can figure out where it's happening?

Manuel030 commented 3 weeks ago

Sure:

  File "/home/manuel/Projects/whisper-finetune/issue.py", line 14, in <module>
    outputs = model.generate(
  File "/home/manuel/Projects/whisper-finetune/venv/lib/python3.10/site-packages/transformers/models/whisper/modeling_tf_whisper.py", line 1646, in generate
    prompt_ids = prompt_ids.tolist()
  File "/home/manuel/Projects/whisper-finetune/venv/lib/python3.10/site-packages/tensorflow/python/framework/ops.py", line 440, in __getattr__
    raise AttributeError(
AttributeError: EagerTensor object has no attribute 'tolist'. 
        If you are looking for numpy-related methods, please run the following:
        from tensorflow.python.ops.numpy_ops import np_config
        np_config.enable_numpy_behavior()

Rocketknight1 commented 3 weeks ago

cc @gante @sanchit-gandhi, the relevant code is

prompt_ids = prompt_ids.tolist()
decoder_start_token_id, *text_prompt_ids = prompt_ids

which is failing in graph mode when the input is tf.Tensor and not np.ndarray. A simple workaround would be something like this:

if isinstance(prompt_ids, np.ndarray):
    prompt_ids = prompt_ids.tolist()
else:
    prompt_ids = prompt_ids.numpy().tolist()

but I'll wait for @gante's feedback here. The reason this matters in the first place is that @Manuel030 wants to compile the generation loop, presumably for XLA/export. Is that possible for TFWhisper, or will we just run into other problems if we fix this line?
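
The workaround above can be sketched as a standalone helper. This is a minimal illustration, not the transformers implementation: `prompt_ids_to_list` is a hypothetical name, the token ids in the usage lines are made up, and duck-typed attribute checks stand in for the isinstance test so the sketch runs without NumPy or TensorFlow installed. It only covers eager tensors. A symbolic tensor inside a compiled graph typically has no `.numpy()` method, so the helper raises a TypeError there instead of solving the graph-mode case.

```python
def prompt_ids_to_list(prompt_ids):
    """Convert prompt ids to a plain Python list (hypothetical helper)."""
    if isinstance(prompt_ids, (list, tuple)):
        return list(prompt_ids)
    if hasattr(prompt_ids, "tolist"):
        # np.ndarray exposes .tolist() directly.
        return prompt_ids.tolist()
    if hasattr(prompt_ids, "numpy"):
        # An eager tf.Tensor exposes .numpy(), which returns an np.ndarray.
        return prompt_ids.numpy().tolist()
    raise TypeError(
        f"Cannot convert {type(prompt_ids).__name__} to a list; "
        "symbolic (graph-mode) tensors are not supported by this sketch."
    )


# Usage with made-up token ids: the first id is split off as the decoder
# start token and the rest form the text prompt, as in the quoted snippet.
decoder_start_token_id, *text_prompt_ids = prompt_ids_to_list([101, 7, 8, 9])
```

Because the conversion still produces a Python list, any code that runs after it (such as the unpacking shown in the usage lines) executes in Python rather than in the graph, which is why this kind of patch may not be enough for XLA or export.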

gante commented 3 weeks ago

👋

Under the hood, TFWhisper's generate calls the OG generate, so it should be compilable! However, I'm not sure if it can be compiled when prompt_ids is set (that code path has things like enumerate, which is often incompatible with compilation).

Regardless of whether it fixes the XLA use case, the suggested change LGTM @Rocketknight1 👍

Manuel030 commented 3 weeks ago

Unfortunately, patching the high-level generate as suggested by @Rocketknight1 did not work. My use case is an export to the TFLite format.

Rocketknight1 commented 3 weeks ago

@Manuel030 I'm not sure we have a good solution in that case. You might have to make some changes to the generate() function in TFWhisper, like the one I made above, until you can get it to compile successfully. If you do, please open a PR to add the changes to the codebase!