huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Add TensorFlow Whisper model for audio classification #21777


sanchit-gandhi commented 1 year ago

Feature request

The PR https://github.com/huggingface/transformers/pull/21754 adds the PyTorch version of WhisperForAudioClassification. It would be great to add the TensorFlow equivalent.

Motivation

Whisper is an encoder-decoder model for speech recognition. However, we can repurpose the model for other speech tasks, such as audio classification.

Audio classification is the task of mapping from an input speech sequence to a single class prediction. For more details, refer to the task page on the Hub: https://huggingface.co/tasks/audio-classification

For audio classification, we only require a single model output. Thus, we do not need the auto-regressive generation capabilities of the Whisper decoder (which are used to generate a sequence of text tokens during speech recognition). Instead, we can just use the Whisper encoder to get hidden states and add a classification head on top to make class label predictions.
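For reference, here is a minimal usage sketch of the PyTorch class added in #21754. Note that loading the plain `openai/whisper-tiny` checkpoint leaves the classification head randomly initialised, so this only illustrates the API, not meaningful predictions:

```python
# Usage sketch for the PyTorch WhisperForAudioClassification from #21754.
# The projector/classifier weights are randomly initialised when loading a plain
# Whisper checkpoint, so the predicted label here is only illustrative.
import numpy as np
import torch
from transformers import AutoFeatureExtractor, WhisperForAudioClassification

feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny")
model = WhisperForAudioClassification.from_pretrained("openai/whisper-tiny", num_labels=3)

# 16 kHz mono audio (random data stands in for a real clip)
audio = np.random.randn(16_000).astype(np.float32)
inputs = feature_extractor(audio, sampling_rate=16_000, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits  # shape: (batch_size, num_labels)

predicted_id = int(logits.argmax(-1))
print(predicted_id, model.config.id2label[predicted_id])  # generic "LABEL_i" names here
```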

This is analogous to using a Wav2Vec2 model for audio classification: the Wav2Vec2 encoder is used to get hidden states, and a classification head is added on top to make class label predictions.

The PR https://github.com/huggingface/transformers/pull/21754 adds the PyTorch version of WhisperForAudioClassification. It required adding a projection layer and classification layer on top of the WhisperEncoder. For more details, refer directly to the pull request.

It would be great to add the TensorFlow equivalent of this model for cross-framework support.
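To make the goal concrete, here is a rough sketch (assumptions only, not the final implementation) of what the TF head could look like, mirroring the PyTorch class: project the encoder hidden states, mean-pool over time, then classify. The real model would also need to handle `use_weighted_layer_sum`, labels / loss computation, and the usual TF output classes.

```python
# Rough sketch only: mirrors the structure of the PyTorch WhisperForAudioClassification
# (projector + classifier on top of the encoder). The class name and details are
# assumptions until the PR exists; loss computation, use_weighted_layer_sum, and
# proper output classes are omitted.
import tensorflow as tf
from transformers import WhisperConfig
from transformers.models.whisper.modeling_tf_whisper import (
    TFWhisperEncoder,
    TFWhisperPreTrainedModel,
)


class TFWhisperForAudioClassification(TFWhisperPreTrainedModel):
    def __init__(self, config: WhisperConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.encoder = TFWhisperEncoder(config, name="encoder")
        self.projector = tf.keras.layers.Dense(config.classifier_proj_size, name="projector")
        self.classifier = tf.keras.layers.Dense(config.num_labels, name="classifier")

    def call(self, input_features, training=False):
        # input_features: (batch, num_mel_bins, seq_len) -> last_hidden_state: (batch, frames, d_model)
        encoder_outputs = self.encoder(input_features, training=training)
        hidden_states = self.projector(encoder_outputs[0])
        pooled_output = tf.reduce_mean(hidden_states, axis=1)  # mean-pool over time
        logits = self.classifier(pooled_output)  # (batch, num_labels)
        return logits
```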

The most difficult part of this PR will be getting the model tester to work. You can see from the PyTorch PR that we require a standalone tester for the audio classification model. This is because the original Whisper model is an encoder-decoder model, but the audio classification model is an encoder-only model. Thus, we require different testing logic.
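As a rough illustration of why the testing logic differs, a standalone smoke test for the encoder-only model might look something like the snippet below (the config sizes are arbitrary tiny values and the model class is the sketch above, not merged code); the actual PR should plug into the tester pattern used in tests/models/whisper/test_modeling_tf_whisper.py:

```python
# Standalone smoke-test sketch for the encoder-only classification model.
# Uses the TFWhisperForAudioClassification sketch above; config values are arbitrary
# tiny sizes, and the real PR should follow the existing TF model tester conventions.
import numpy as np
import tensorflow as tf
from transformers import WhisperConfig


def test_tf_whisper_audio_classification_logits_shape():
    config = WhisperConfig(
        d_model=16,
        encoder_layers=2,
        encoder_attention_heads=4,
        decoder_layers=2,
        decoder_attention_heads=4,
        num_mel_bins=80,
        max_source_positions=30,
        classifier_proj_size=8,
        num_labels=3,
    )
    model = TFWhisperForAudioClassification(config)

    # Whisper expects log-mel features of shape (batch, num_mel_bins, 2 * max_source_positions)
    input_features = tf.constant(
        np.random.randn(2, config.num_mel_bins, 2 * config.max_source_positions),
        dtype=tf.float32,
    )
    logits = model(input_features)
    assert logits.shape == (2, config.num_labels)
```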

Your contribution

Opening this one up to the community! If you're interested in tackling this, feel free to drop a comment in this thread and open a PR when you're ready. More than happy to answer any questions / queries about this integration!

OllieBroadhurst commented 1 year ago

Hey @sanchit-gandhi, if we're just using the encoder, do you think a CTC head could also work, i.e. WhisperForCTC?

sanchit-gandhi commented 1 year ago

Hey @OllieBroadhurst! I don't think an encoder-only Whisper model for speech recognition would be super practical, since we'd then need an external language model to correct the phonetic errors made by the CTC model. IMO we're better off using the internal language model provided by the decoder in the original encoder-decoder architecture. The encoder-decoder model is trained end-to-end on all of the Whisper pre-training data, so it's likely going to be better than any combination of CTC + LM we train ourselves.

adit299 commented 1 year ago

Hello @OllieBroadhurst, are you currently working on this? I would love to help out if you need it. Otherwise, I would like to take a look at this issue.

OllieBroadhurst commented 1 year ago

Hi @adit299! I'm not, so you can take it away!

adit299 commented 1 year ago

Great, will do!