Open sanchit-gandhi opened 1 year ago
Hey @sanchit-gandhi, if we're just using the encoder, do you think a CTC head could also work, i.e. `WhisperForCTC`?
Hey @OllieBroadhurst! I don't think an encoder-only Whisper model for speech recognition would be super practical, since we'd then need an external language model to correct the phonetic errors made by the CTC model. IMO we're better off using the internal language model provided by the decoder in the original encoder-decoder architecture. The encoder-decoder model is trained end-to-end on all of the Whisper pre-training data, so it's likely going to be better than any combination of CTC + LM we train ourselves
Hello @OllieBroadhurst are you currently working on this? I would love to help out if I can/you need it. Otherwise, I would like to take a look at this issue.
Hi @adit299! I'm not, so you can take it away!
Great, will do!
Feature request
The PR https://github.com/huggingface/transformers/pull/21754 adds the PyTorch version of `WhisperForAudioClassification`. It would be great to add the TensorFlow equivalent.

Motivation
Whisper is an encoder-decoder model for speech recognition. However, we can repurpose the model for other speech tasks, such as audio classification.
Audio classification is the task of mapping from an input speech sequence to a single class prediction. For more details, refer to the task page on the Hub: https://huggingface.co/tasks/audio-classification
For audio classification, we only require a single model output. Thus, we do not need the auto-regressive generation capacities of the Whisper decoder (which is used to generate a sequence of text tokens during speech recognition). Instead, we can just use the Whisper encoder to get hidden states, and add a classification head on top to make class label predictions.
This is analogous to using a Wav2Vec2 model for audio classification: the Wav2Vec2 encoder is used to get hidden states, and a classification head added on top to make class label predictions.
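For a concrete picture of the target behaviour, here's a rough usage sketch of the existing PyTorch model. The checkpoint name is illustrative (a language-ID fine-tune); any Whisper audio-classification checkpoint would work, and the TensorFlow port should support the same workflow with `return_tensors="tf"`:

```python
from datasets import load_dataset
from transformers import AutoFeatureExtractor, WhisperForAudioClassification

# Illustrative checkpoint: a Whisper encoder fine-tuned for language identification
checkpoint = "sanchit-gandhi/whisper-medium-fleurs-lang-id"
feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
model = WhisperForAudioClassification.from_pretrained(checkpoint)

# Grab one audio sample to classify
ds = load_dataset("google/fleurs", "all", split="validation", streaming=True)
sample = next(iter(ds))

# Convert raw audio to log-mel input features, then take the argmax over class logits
inputs = feature_extractor(
    sample["audio"]["array"],
    sampling_rate=sample["audio"]["sampling_rate"],
    return_tensors="pt",
)
logits = model(**inputs).logits
predicted_class_id = logits.argmax(-1).item()
print(model.config.id2label[predicted_class_id])
```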
The PR https://github.com/huggingface/transformers/pull/21754 adds the PyTorch version of `WhisperForAudioClassification`. It required adding a projection layer and a classification layer on top of the `WhisperEncoder`. For more details, refer directly to the pull request.

It would be great to add the TensorFlow equivalent of this model for cross-framework support.
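As a very rough sketch of what the head could look like in TensorFlow (this is not the actual implementation; layer names and the `classifier_proj_size` / `num_labels` config fields are assumptions based on the PyTorch PR):

```python
import tensorflow as tf

class TFWhisperClassificationHead(tf.keras.layers.Layer):
    """Sketch of the projector + classifier head described in the PyTorch PR.

    Assumes `config` exposes `classifier_proj_size` and `num_labels`,
    mirroring the PyTorch WhisperForAudioClassification.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        # Project encoder hidden states down to the classifier dimension
        self.projector = tf.keras.layers.Dense(config.classifier_proj_size, name="projector")
        # Map the pooled representation to class logits
        self.classifier = tf.keras.layers.Dense(config.num_labels, name="classifier")

    def call(self, encoder_hidden_states):
        # encoder_hidden_states: (batch, seq_len, hidden_size) from the Whisper encoder
        hidden_states = self.projector(encoder_hidden_states)
        # Mean-pool over the time dimension to get one vector per example
        pooled_output = tf.reduce_mean(hidden_states, axis=1)
        return self.classifier(pooled_output)  # (batch, num_labels)
```

Note the PyTorch model also supports a weighted sum over encoder layer outputs (`use_weighted_layer_sum`), which is omitted from this sketch for brevity.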
The most difficult part of this PR will be getting the model tester to work. You can see from the PyTorch PR that we require a standalone tester for the audio classification model. This is because the original Whisper model is an encoder-decoder model, but the audio classification model is an encoder-only model. Thus, we require different testing logic.
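Whatever shape the standalone tester takes, one sanity check worth including is a PT/TF cross-framework equivalence test, roughly along these lines (hedged sketch only: `TFWhisperForAudioClassification` is the class this issue proposes and doesn't exist yet, the tiny-random checkpoint name is a placeholder, and the tolerance is arbitrary):

```python
import numpy as np
import torch

from transformers import WhisperForAudioClassification  # existing PyTorch model
# from transformers import TFWhisperForAudioClassification  # to be added by this PR

def test_pt_tf_equivalence(checkpoint="hf-internal-testing/tiny-random-whisper"):  # placeholder
    pt_model = WhisperForAudioClassification.from_pretrained(checkpoint)
    # Load the PyTorch weights directly into the new TF model
    tf_model = TFWhisperForAudioClassification.from_pretrained(checkpoint, from_pt=True)

    # Random log-mel features with the shape the Whisper encoder expects:
    # (batch, num_mel_bins, 2 * max_source_positions)
    config = pt_model.config
    input_features = np.random.randn(
        1, config.num_mel_bins, 2 * config.max_source_positions
    ).astype("float32")

    with torch.no_grad():
        pt_logits = pt_model(input_features=torch.from_numpy(input_features)).logits.numpy()
    tf_logits = tf_model(input_features=input_features).logits.numpy()

    # Logits from the two frameworks should match to within numerical tolerance
    np.testing.assert_allclose(pt_logits, tf_logits, atol=1e-5)
```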
Your contribution
Opening this one up to the community! If you're interested in tackling this, feel free to drop a comment in this thread and open a PR when you're ready. More than happy to answer any questions / queries about this integration!