huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Bug in WhisperTokenizer `batch_decode` when setting `skip_special_tokens=True` on FlaxWhisper model output #32936

Closed hannan72 closed 1 month ago

hannan72 commented 2 months ago

Who can help?

@sanchit-gandhi

Reproduction

I use the following code to transcribe a sample audio file with the Flax Whisper-large-v3 model in JAX:

import os
# Disable JAX's default preallocation of most GPU memory; set this before JAX initializes its backend.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import jax.numpy as jnp
from scipy.io import wavfile
from transformers import FlaxWhisperForConditionalGeneration, WhisperFeatureExtractor, WhisperTokenizer

model_path = "openai/whisper-large-v3"
audio_file_path = "path/to/audio/audio_file.wav"

tokenizer = WhisperTokenizer.from_pretrained(model_path)
# The feature extractor turns raw audio into the log-mel input features the model expects.
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_path)
# from_pt=True converts the PyTorch checkpoint to Flax, so PyTorch must be installed.
model = FlaxWhisperForConditionalGeneration.from_pretrained(model_path, dtype=jnp.float16, from_pt=True)

jit_generate = jax.jit(model.generate, static_argnames=["max_length", "language", "task"])

# Assumes a 16 kHz, 16-bit PCM WAV file: scale the int16 samples to [-1, 1).
samplerate, data_waveform = wavfile.read(audio_file_path)
data_waveform = data_waveform / 32768.0
input_features = feature_extractor(data_waveform, padding="max_length", sampling_rate=16000, return_tensors="np").input_features
input_features = jnp.array(input_features, dtype=jnp.float16)
pred_ids = jit_generate(input_features, max_length=128, language="<|de|>", task="transcribe")
print(tokenizer.batch_decode(pred_ids.sequences, skip_special_tokens=True))
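
The division by 32768.0 in the snippet assumes 16-bit PCM audio that is already sampled at 16 kHz. Below is a minimal sketch of a more defensive conversion, assuming the file is read with `scipy.io.wavfile.read` as in the snippet; `pcm_to_float32` is a hypothetical helper, and it rejects other sample rates rather than resampling.

```python
import numpy as np


def pcm_to_float32(samples: np.ndarray, samplerate: int, expected_rate: int = 16000) -> np.ndarray:
    """Hypothetical helper: turn scipy.io.wavfile.read output into mono float32 in [-1, 1]."""
    if samplerate != expected_rate:
        # Whisper's feature extractor expects 16 kHz input; this sketch does not resample.
        raise ValueError(f"expected {expected_rate} Hz audio, got {samplerate} Hz")
    if samples.dtype == np.int16:        # 16-bit PCM
        audio = samples.astype(np.float32) / 32768.0
    elif samples.dtype == np.int32:      # 32-bit PCM, and 24-bit PCM scaled to the int32 range
        audio = samples.astype(np.float32) / 2147483648.0
    elif samples.dtype == np.uint8:      # 8-bit PCM is unsigned and centred on 128
        audio = (samples.astype(np.float32) - 128.0) / 128.0
    elif np.issubdtype(samples.dtype, np.floating):
        audio = samples.astype(np.float32)
    else:
        raise TypeError(f"unsupported sample dtype: {samples.dtype}")
    if audio.ndim == 2:                  # (frames, channels): average the channels to mono
        audio = audio.mean(axis=1)
    return audio
```

It would take the place of the division by 32768.0, e.g. `data_waveform = pcm_to_float32(data_waveform, samplerate)`.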

This worked correctly up to transformers 4.42.4, but since 4.43.0 the last line (`batch_decode`) raises an error:

  File "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py", line 3994, in batch_decode
    return [
  File "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py", line 3995, in <listcomp>
    self.decode(
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/tokenization_whisper.py", line 692, in decode
    filtered_ids = self._preprocess_token_ids(
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/tokenization_whisper.py", line 637, in _preprocess_token_ids
    token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/tokenization_whisper.py", line 860, in _strip_prompt
    if not token_ids:
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/array.py", line 258, in __bool__
    core.check_bool_conversion(self)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 654, in check_bool_conversion
    raise ValueError("The truth value of an array with more than one element"
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

However, if I call `batch_decode` with `skip_special_tokens=False`, no error is raised, but the output is full of special tokens.
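
Until a fixed release is installed, one possible workaround (an assumption, not confirmed in this thread) is to hand `batch_decode` plain Python lists instead of a JAX array, e.g. `tokenizer.batch_decode(pred_ids.sequences.tolist(), skip_special_tokens=True)`. A minimal sketch wrapping that idea in a hypothetical helper, `decode_to_text`:

```python
def decode_to_text(tokenizer, sequences, skip_special_tokens=True):
    """Hypothetical helper: decode generated ids via plain Python lists.

    `sequences` may be an array with a .tolist() method (JAX, NumPy, PyTorch)
    or an iterable of token-id sequences. Converting first means no array
    ever reaches a truthiness check inside the tokenizer.
    """
    if hasattr(sequences, "tolist"):
        token_lists = sequences.tolist()
    else:
        token_lists = [list(seq) for seq in sequences]
    return tokenizer.batch_decode(token_lists, skip_special_tokens=skip_special_tokens)
```

With the reproduction code, it would be called as `decode_to_text(tokenizer, pred_ids.sequences)`.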

Expected behavior

`batch_decode` should return a list of strings, as it did up to transformers 4.42.4.

hannan72 commented 2 months ago

Any updates @sanchit-gandhi @ArthurZucker ?

LysandreJik commented 2 months ago

@eustlb, would you mind looking at this issue if you have some bandwidth?

hannan72 commented 2 months ago

I think the error comes from the `if not token_ids:` check in `tokenization_whisper.py`, which fails when `token_ids` is a JAX array: https://github.com/huggingface/transformers/blob/d1f39c484d8347aa7b3170ea250a1e8f3bdfdf31/src/transformers/models/whisper/tokenization_whisper.py#L852

This check works when `token_ids` is a PyTorch tensor or a NumPy array, but a JAX array cannot be used directly in a boolean context (e.g., `if not jax_array:`), so JAX raises an error:

  File "/usr/local/lib/python3.10/dist-packages/jax/_src/array.py", line 258, in __bool__
    core.check_bool_conversion(self)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 654, in check_bool_conversion
    raise ValueError("The truth value of an array with more than one element"
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
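
The failure can be reproduced without loading a model. The minimal sketch below uses a NumPy array as a stand-in for the JAX array (NumPy raises the same `ValueError` for a bare multi-element array in a boolean context) and shows one way to make the emptiness test safe: convert the ids to a Python list first. `to_token_list` and `is_empty` are hypothetical helpers, not the code merged in the PR.

```python
import numpy as np


def to_token_list(token_ids):
    """Hypothetical helper: normalise token ids to a plain Python list.

    NumPy arrays, JAX arrays and PyTorch tensors all provide .tolist();
    lists pass through unchanged.
    """
    if isinstance(token_ids, list):
        return token_ids
    if hasattr(token_ids, "tolist"):
        return token_ids.tolist()
    return list(token_ids)


def is_empty(token_ids):
    # `not` on a Python list is unambiguous: it is True only for [].
    return not to_token_list(token_ids)


ids = np.array([7, 8, 9])  # arbitrary example token ids

try:
    if not ids:  # multi-element array in a boolean context
        pass
except ValueError as err:
    print(f"ValueError: {err}")

print(is_empty(ids))                           # False
print(is_empty(np.array([], dtype=np.int64)))  # True
```

Converting through `.tolist()` keeps the check framework-agnostic; testing `len(token_ids) == 0` on a 1-D array would also avoid the ambiguity.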

ArthurZucker commented 2 months ago

Feel free to open a PR with a fix, then!

hannan72 commented 2 months ago

I created a PR: https://github.com/huggingface/transformers/pull/33151

@ArthurZucker, please review and merge it.

github-actions[bot] commented 1 month ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

ArthurZucker commented 1 month ago

Closing as it was merged!