DavraYoung opened this issue 1 year ago
Regarding accuracy and whether the model actually works: I was able to achieve decent accuracy (WER of 7% on my test dataset) with only 10% of my dataset (100k audios), starting from the WhisperForConditionalGeneration encoder checkpoint.
Adding as a feature request FYI @sanchit-gandhi
That's very cool @DavraYoung! Did you build the CTC tokenizer yourself as well? And how did WhisperForCTC compare to WhisperForConditionalGeneration when fine-tuned on your dataset? We could run very fast training by freezing the entire encoder block and only fine-tuning the CTC head.
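Something like this minimal sketch would do it (assuming a hypothetical `WhisperForCTC` model with an `encoder` attribute; none of these names are an official API):

```python
import torch

def freeze_encoder(model: torch.nn.Module, lr: float = 1e-4):
    """Freeze `model.encoder` (assumed attribute name) and return an optimizer
    over the remaining trainable parameters, i.e. the CTC head."""
    for param in model.encoder.parameters():
        param.requires_grad = False
    return torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad), lr=lr
    )
```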
Hi @sanchit-gandhi. Regarding the tokenizer: I used the Wav2Vec2 tokenizer with a custom vocab (the Latin lowercase alphabet plus `'`), like in the Wav2Vec2 fine-tuning tutorial.
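The vocab/tokenizer setup looked roughly like this (a sketch following the Wav2Vec2 fine-tuning tutorial; the exact character set here is just an example):

```python
import json
from transformers import Wav2Vec2CTCTokenizer

# Latin lowercase alphabet plus apostrophe, as in the Wav2Vec2 fine-tuning tutorial.
chars = list("abcdefghijklmnopqrstuvwxyz'") + [" "]
vocab = {c: i for i, c in enumerate(sorted(set(chars)))}
vocab["|"] = vocab[" "]      # use "|" as the word delimiter
del vocab[" "]
vocab["[UNK]"] = len(vocab)
vocab["[PAD]"] = len(vocab)  # also serves as the CTC blank token

with open("vocab.json", "w") as f:
    json.dump(vocab, f)

tokenizer = Wav2Vec2CTCTokenizer(
    "vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|"
)
```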
Regarding performance: I cannot directly compare the models right now, since I trained WhisperForConditionalGeneration on a slightly different dataset (some entries are not present), and some entries from my current validation dataset were present in the training data of WhisperForConditionalGeneration.
Regarding actual performance on unseen data, I think WhisperForConditionalGeneration is much better than any CTC-based model, especially when given previous context/prompt, but it requires a good dataset of long audios with enough preceding text context. CTC-head-based models, on the other hand, do not require a dataset with diverse lengths and can operate on shorter audios, like in my current dataset. That's why I was investigating a CTC head on top of the Whisper encoder.
If it's needed, I can spend some time training a whisper-ctc-medium-960h LibriSpeech English model with the Wav2Vec2-base-960h tokenizer vocab.
After playing around with such a model, I found issues with degraded performance on the validation dataset.
My training setup:
- openai/whisper-large encoder + a new CTC head
- 400 hours of short Uzbek audios
- modified encoder with partial positional encodings
I couldn't make the frozen version converge, so I switched to training with the encoder unfrozen. The model was showing good results on 128 test samples, but after training for 2 epochs I ran it on the validation dataset and unfortunately it showed worse results than Wav2Vec2ForCTC on the same 2048-clip dataset.
I modified the positional embeddings in the encoder:

```python
inputs_embeds = inputs_embeds.permute(0, 2, 1)  # (batch, frames, d_model)
embed_pos = self.embed_positions.weight
# slice embed_pos to the same number of frames as inputs_embeds,
# so inputs padded to less than 30s still line up
embed_pos = embed_pos[: inputs_embeds.shape[1], :]
```

Probably this change caused the issue.
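For context, the slice sits at the start of the encoder forward pass, roughly like this (a standalone sketch of my modification, not the stock Transformers code):

```python
from torch import nn

def encode_with_sliced_positions(encoder, input_features):
    """Sketch of the front of WhisperEncoder.forward with the positional-embedding
    slice applied, so inputs padded to less than 30s still line up."""
    inputs_embeds = nn.functional.gelu(encoder.conv1(input_features))
    inputs_embeds = nn.functional.gelu(encoder.conv2(inputs_embeds))
    inputs_embeds = inputs_embeds.permute(0, 2, 1)  # (batch, frames, d_model)

    # Stock Whisper adds all 1500 positions (30s of audio); slice to the
    # actual number of frames instead.
    embed_pos = encoder.embed_positions.weight[: inputs_embeds.shape[1], :]
    hidden_states = inputs_embeds + embed_pos

    # Run the unchanged encoder layers and final layer norm.
    for layer in encoder.layers:
        hidden_states = layer(hidden_states, attention_mask=None,
                              layer_head_mask=None)[0]
    return encoder.layer_norm(hidden_states)
```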
During training I observe an overall good loss of 0.03-0.07, but I repeatedly see the loss jump to 0.2-0.4. I checked the audios for those sample entries and they are fine. It seems like the model fails to capture certain features.
Here is the list of my models trained on the same dataset (except the phone versions). The phone versions were trained on top of the general model with random audio sample-rate reduction plus 10 hours of real phone data:

| Model Name | Average CER | Average WER |
|---|---|---|
| general-wav2vec2-1b-overfited-11 | 0.004 | 0.029 |
| general-wav2vec2-2b-10-07-2023 | 0.008 | 0.054 |
| general-medium-wav2vec2-conformer | 0.014 | 0.091 |
| mixed-wav2vec-2b-01-09-2023 | 0.014 | 0.086 |
| general-wav2vec2-medium-14-09-2023 | 0.023 | 0.139 |
| general-wav2vec2-small-13-09-2023 | 0.056 | 0.305 |
| general-wav2vec2-100m-11-09-2023 | 0.169 | 0.723 |
| general-whisper-large-ctc-18-09-2023 | 0.178 | 0.197 |
| phone-whisper-large-ctc-18-09-2023 | 0.187 | 0.249 |
| general-whisper-medium-ctc-18-09-2023 | 0.191 | 0.265 |
| general-whisper-ultra-ctc-18-09-2023 | 0.19 | 0.26 |
Hey @DavraYoung! Thanks for explaining a bit more about how you're using this model. My only concern with adding this to the Transformers library is that it's a bit of an 'unofficial' implementation? That is to say, there are no official pre-trained weights available for the model.
Having pre-trained weights and reference code are general prerequisites for adding a model to Transformers. My view is that Whisper for CTC is a nice extension of the Whisper Transformers code to a different modelling paradigm. However, it's not one that is natively supported in the original Whisper library, nor does it have official weights on the Hub. Hence, it could be a nice example that you share in a standalone repo of your own? Or showcase on the Hub with your fine-tuned weights?
Regarding CTC vs Enc-Dec, there are some nice comparisons in the ESB paper: https://arxiv.org/abs/2210.13352
Notably, CTC performs worse than Enc-Dec with punctuation and casing, and makes more frequent spelling errors. It's these aspects of robustness that make Whisper Enc-Dec such an appealing model for ASR, so I think promoting the Enc-Dec architecture is sensible here.
It's hard to comment on your training results without any information or statistics on the training data or task, but they seem to suggest that Wav2Vec2 is a more promising approach for encoder-only CTC decoding.
@DavraYoung Thanks for sharing your great work! Can I use your code to compare against other CTC-based models, like HuBERT / WavLM / XLS-R / MMS?
As @sanchit-gandhi mentioned, comparing the results with other CTC-based models will give WhisperEncoderForCTC a fairer comparison!
@cjw414 yes, no problem. I would also recommend changing the encoder positional embeddings so the model works properly with `padding="longest"`:

```python
# slice embed_pos to the same number of frames as inputs_embeds
embed_pos = embed_pos[: inputs_embeds.shape[1], :]
```

Otherwise you will need to use WhisperFeatureExtractor with every audio padded to 30s, which may hurt training speed if your average audio length is short.
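For illustration, the difference looks roughly like this (a sketch; the shapes assume 16 kHz audio and the 80-mel `openai/whisper-large` extractor):

```python
import numpy as np
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-large")
audio = np.random.randn(16000 * 5)  # dummy 5-second clip at 16 kHz

# Default behaviour: pad/truncate every clip to 30s of mel frames.
padded = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")
print(padded.input_features.shape)   # (1, 80, 3000)

# With the sliced positional embeddings you can instead pad only to the
# longest clip in the batch, which is much cheaper for short audio.
longest = feature_extractor(audio, sampling_rate=16000,
                            padding="longest", return_tensors="pt")
print(longest.input_features.shape)  # (1, 80, 500)
```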
Thanks, but in my personal experience Whisper did not work well when given short audios without padding. Maybe I did something wrong, but I might just start with padding to 30s.
Will let you know if I get some findings!
I also found that in my Distil-Whisper experiments padding to 30s worked better than non-padding! Likely because the model is pre-trained on such a vast amount of data that it ends up working so strongly on padded inputs
Yup, that's probably one of the reasons Whisper suffers from slow inference speed. Btw, I really liked your Distil-Whisper!!
Feature request
Request to add WhisperForCTC model.
Motivation
It would be cool if we had a custom WhisperForCTC with the Whisper encoder and a CTC head, just like Wav2Vec2, but since Whisper is based on mel spectrograms, I think it may bring better results.
Your contribution
Here is my implementation; I mostly copied it from Wav2Vec2ForCTC. NOTE: there is a TODO that needs to be resolved. I didn't test that part, since Whisper operates with a transposed hidden_states shape.
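Roughly, the model looks like this (a condensed sketch of the idea, adapted from Wav2Vec2ForCTC; `WhisperForCTC` and the head names here are my own, not part of Transformers):

```python
import torch
import torch.nn as nn
from transformers import WhisperConfig
from transformers.models.whisper.modeling_whisper import (
    WhisperEncoder,
    WhisperPreTrainedModel,
)

class WhisperForCTC(WhisperPreTrainedModel):
    """Sketch: Whisper encoder + linear CTC head, in the spirit of Wav2Vec2ForCTC."""

    def __init__(self, config: WhisperConfig, vocab_size: int, blank_id: int = 0):
        super().__init__(config)
        self.encoder = WhisperEncoder(config)
        self.dropout = nn.Dropout(0.1)
        self.ctc_head = nn.Linear(config.d_model, vocab_size)
        self.blank_id = blank_id
        self.post_init()

    def forward(self, input_features, labels=None):
        # Whisper encoder outputs are already (batch, frames, d_model).
        hidden_states = self.encoder(input_features).last_hidden_state
        logits = self.ctc_head(self.dropout(hidden_states))

        loss = None
        if labels is not None:
            # CTC loss expects (frames, batch, vocab) log-probabilities.
            log_probs = nn.functional.log_softmax(logits, dim=-1).transpose(0, 1)
            input_lengths = torch.full(
                (logits.shape[0],), logits.shape[1],
                dtype=torch.long, device=logits.device,
            )
            # Labels are padded with -100, as in the Wav2Vec2 fine-tuning setup.
            label_lengths = (labels >= 0).sum(dim=-1)
            flat_labels = labels[labels >= 0]
            loss = nn.functional.ctc_loss(
                log_probs, flat_labels, input_lengths, label_lengths,
                blank=self.blank_id, zero_infinity=True,
            )
        return {"loss": loss, "logits": logits}
```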