Closed rafaelvp-db closed 1 year ago
hey @SeanNaren just made the PR, would be great if you could take a look and share some insights :)
Hi @rafaelvp-db thanks for making the PR!
There are quite a few pieces missing, but hopefully I can help you get the right things implemented!
Firstly, I don't think your current code is entirely correct (unless I've made a mistake). The Wav2Vec model/dataset/tokenizer are completely different and should probably exist as new classes inherited from TaskTransformer/TransformerDataModule.
What I think would be a good idea is to get this blog post implemented into a TaskTransformer and a TransformerDataModule as you've already outlined: https://huggingface.co/blog/fine-tune-wav2vec2-english
This would involve training_step, validation_step and test_step in the SpeechRecognitionTransformer, using the WER metric found in torchmetrics: https://torchmetrics.readthedocs.io/en/stable/text/word_error_rate.html?highlight=WER
It would also involve a SpeechRecognitionDataModule.
Overall I would assume something like this to work:
import pytorch_lightning as pl

from lightning_transformers.task.audio.speech_recognition import (
    SpeechRecognitionDataConfig,
    SpeechRecognitionDataModule,
    SpeechRecognitionTransformer,
)

if __name__ == "__main__":
    model = SpeechRecognitionTransformer("facebook/wav2vec2-base", ctc_loss_reduction="mean", vocab_file="vocab.json")
    dm = SpeechRecognitionDataModule(
        cfg=SpeechRecognitionDataConfig(
            batch_size=1,
            dataset_name="timit_asr",
        ),
        tokenizer=model.tokenizer,
    )
    trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)
    trainer.fit(model, dm)
Thanks for the guidelines @SeanNaren! Let me look into that.
Merging #251 (a3d8528) into master (9f25baa) will decrease coverage by 1%. The diff coverage is 58%.
Current head a3d8528 differs from pull request most recent head 45cbab7. Consider uploading reports for the commit 45cbab7 to get more accurate results.
@@          Coverage Diff          @@
##          master   #251    +/-   ##
=====================================
- Coverage     75%    74%    -1%
=====================================
  Files         73     77     +4
  Lines       1622   1682    +60
=====================================
+ Hits        1210   1245    +35
- Misses       412    437    +25
How's it going @rafaelvp-db?
The code looks muuuch nicer, amazing job! Anything I can assist with? I notice that the example requires a vocab.json, I'm sure we can YOLO it and use the alphabet.
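To make the alphabet idea concrete, here is a minimal sketch of generating a character-level vocab from the lowercase English alphabet. It assumes the conventions from the Hugging Face blog post linked above: "|" as the word delimiter in place of a space, plus "[UNK]" and "[PAD]" tokens. The function names, the apostrophe entry and the exact token set are assumptions for illustration, not what the PR does.

```python
import json
import string


def build_alphabet_vocab():
    """Map each character token to an integer id (hypothetical helper)."""
    # Lowercase letters, an apostrophe for contractions, and "|" standing in
    # for the space between words (assumed, following the blog post).
    tokens = list(string.ascii_lowercase) + ["'", "|"]
    vocab = {token: index for index, token in enumerate(tokens)}
    # Special tokens are appended last so they get the highest ids.
    vocab["[UNK]"] = len(vocab)
    vocab["[PAD]"] = len(vocab)
    return vocab


def save_vocab(vocab, path):
    """Write the vocab as JSON, e.g. to the vocab.json the example expects."""
    with open(path, "w", encoding="utf-8") as vocab_file:
        json.dump(vocab, vocab_file)
```

Calling save_vocab(build_alphabet_vocab(), "vocab.json") would write a 30-entry file. Whether the tokenizer in this PR accepts it as-is depends on it being configured with the same unknown, padding and word-delimiter tokens.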
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
Initial PR for ASR support