huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

WhisperFeatureExtractor padding='longest' cause whisper model to fail. #26241

Closed: DavraYoung closed this issue 1 year ago

DavraYoung commented 1 year ago

System Info

Hi, I found huge memory consumption with my WhisperForAudioClassification model even when I supplied small audio clips. It turns out WhisperFeatureExtractor always pads features to 30-second chunks, even when my audio is only 200 ms long (I was doing per-word speaker embedding).

I then tried specifying padding='longest', which should not pad the features beyond the audio's actual length, but it turns out WhisperEncoder does not support a dynamic number of position embeddings, causing it to fail.

How I solved the problem:

# trim embed_pos to the same sequence length as inputs_embeds
embed_pos = embed_pos[: inputs_embeds.shape[1], :]

https://github.com/huggingface/transformers/blob/e469be340673d1f6931eb22562efd2be7f5a5b8d/src/transformers/models/whisper/modeling_whisper.py#L902
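For context, this is roughly where the change sits in WhisperEncoder.forward (paraphrased from the linked line above, with the proposed fix applied):

inputs_embeds = inputs_embeds.permute(0, 2, 1)
embed_pos = self.embed_positions.weight

# proposed fix: trim the positional embeddings to the actual sequence length
embed_pos = embed_pos[: inputs_embeds.shape[1], :]

hidden_states = inputs_embeds + embed_pos  # no longer fails for inputs shorter than 30 s
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)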

Who can help?

@sanchit-gandhi

Reproduction

How to reproduce the issue:

import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperFeatureExtractor

np_audio = np.zeros(16000, dtype=np.float32)  # stand-in for a real clip: 16k samples = 1 s at 16 kHz
print("np_audio.shape=", np_audio.shape)  # np_audio.shape= (16000,)
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
features = feature_extractor(np_audio, sampling_rate=16000, return_tensors="pt", padding="longest").input_features
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
print("features.shape=", features.shape)  # features.shape= torch.Size([1, 80, 100]). batch_size=1, feature_size=80, seq_len=100
model(features)  # raises RuntimeError inside WhisperEncoder.forward (see traceback below)

In my case I supplied 1000 ms of audio (16k samples, i.e. 100 mel frames, which the encoder downsamples to 50 positions versus the fixed 1500 positional embeddings), and the model threw this error:

RuntimeError: The size of tensor a (50) must match the size of tensor b (1500) at non-singleton dimension 1
File C:\projects\stt\venv\lib\site-packages\transformers\models\whisper\modeling_whisper.py:902, in WhisperEncoder.forward(self, input_features, attention_mask, head_mask, output_attentions, output_hidden_states, return_dict)
    899 inputs_embeds = inputs_embeds.permute(0, 2, 1)
    900 embed_pos = self.embed_positions.weight
--> 902 hidden_states = inputs_embeds + embed_pos
    903 hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
    905 encoder_states = () if output_hidden_states else None

Expected behavior

I expect the model to return an output without an error, at least when working with the classification head.

sanchit-gandhi commented 1 year ago

Hey @DavraYoung! The behaviour you've encountered is how the Whisper model deals with padded/truncated inputs: all input audios are padded/truncated to 30 seconds, regardless of their length, before being converted to log-mel spectrograms. The model is then trained without an attention mask; instead, it learns to ignore the padded portion of the spectrogram directly.

At inference time, we have to match the paradigm the model was trained on, i.e. always pad/truncate audios to 30 seconds. This is why the feature extractor and positional embeddings always expect log-mel spectrograms with a sequence length of 1500, which corresponds to 30 seconds of audio input.
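Concretely, the default feature extractor settings always produce 3000 mel frames (30 seconds at a 10 ms hop), which the encoder's stride-2 convolution downsamples to 1500 positions. A quick illustration, assuming a 16 kHz input:

import numpy as np
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
audio = np.zeros(16000, dtype=np.float32)  # 1 s of audio, stand-in for real speech
features = feature_extractor(audio, sampling_rate=16000, return_tensors="pt").input_features
print(features.shape)  # torch.Size([1, 80, 3000]): padded to 30 s regardless of the input length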

You'll find that the OpenAI Whisper implementation also forces the inputs to always be 30 seconds. The Transformers implementation matches this for strict one-to-one equivalence.

If you're interested in passing shorter log-mels, you can set the corresponding attribute in the feature extractor, and slice the positional embeddings to the required length.

Here's a code snippet showing how you can achieve this, slicing to a sequence length of 500 (corresponding to 10 seconds of audio input): https://github.com/sanchit-gandhi/codesnippets/blob/main/whisper-reduce-context.ipynb
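As a rough sketch of the idea (not the notebook's exact code; it assumes a 16 kHz mono waveform and that attribute/shape details match a recent transformers version):

import numpy as np
import torch
from transformers import WhisperFeatureExtractor, WhisperForConditionalGeneration

target_seconds = 10                     # shorter context window instead of the default 30 s
target_positions = target_seconds * 50  # Whisper uses 50 encoder positions per second (1500 for 30 s)

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

# 1) pad/truncate the audio to 10 s rather than 30 s
np_audio = np.zeros(16000, dtype=np.float32)  # stand-in for a real 1 s clip at 16 kHz
features = feature_extractor(
    np_audio,
    sampling_rate=16000,
    padding="max_length",
    max_length=target_seconds * 16000,
    truncation=True,
    return_tensors="pt",
).input_features  # shape: (1, 80, 1000)

# 2) slice the encoder's positional embedding table to the matching length
encoder = model.get_encoder()
encoder.embed_positions.weight.data = encoder.embed_positions.weight.data[:target_positions]
model.config.max_source_positions = target_positions  # keep the config consistent with the sliced table

# the encoder now accepts the shorter spectrogram without the size mismatch
with torch.no_grad():
    encoder_outputs = encoder(features)
print(encoder_outputs.last_hidden_state.shape)  # torch.Size([1, 500, 768]) for whisper-small

Note that the pretrained model has only ever seen 30-second inputs, so output quality with a reduced context is not guaranteed without further fine-tuning.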

There's a justification for why we don't slice on-the-fly here: https://github.com/huggingface/transformers/issues/25744#issuecomment-1703112076

sanchit-gandhi commented 1 year ago

Hey @DavraYoung - did the above explanation help with tackling your issue?

DavraYoung commented 1 year ago

Hi @sanchit-gandhi. If you mean the WhisperCTC model implementation, then no, it didn't help. That said, I did try training it only with padding="longest" and a modified encoder, and I think that should not affect the accuracy much.

I will have time to come back to the experiments with this model in 2 weeks

sanchit-gandhi commented 1 year ago

Is it ok if we close the issue given that we're keeping the Whisper input context length fixed? We can continue to discuss Whisper CTC on the other dedicated issue thread!