facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License
30.22k stars · 6.38k forks

How can I use a language model with wav2vec decoding #3455

Open coen22 opened 3 years ago

coen22 commented 3 years ago

❓ Questions and Help

What is your question?

I've been looking for an ASR solution to transcribe Dutch audio. I found that the new wav2vec 2.0 works really well. However, I was wondering if I could use a base transformer language model (e.g. https://huggingface.co/pdelobelle/robbert-v2-dutch-base) to improve the results. I saw that you wrote: "To use the transformer language model, use `--w2l-decoder fairseqlm`." But could you give me a little more guidance, or point me to some resources where I can learn about this?

Code

import librosa
import torch
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2ForCTC, Wav2Vec2Processor

audio, rate = librosa.load("untitled.wav", sr=16000)

model_name = "facebook/wav2vec2-large-xlsr-53-dutch"
device = "cpu"
chars_to_ignore_regex = r'[\,\?\.\!\-\;\:\"]'  # noqa: W605

processor = Wav2Vec2Processor.from_pretrained(model_name)
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)

input_values = processor(audio, sampling_rate=rate, return_tensors="pt").input_values

with torch.no_grad():
    logits = model(input_values.to(device)).logits

predicted_ids = torch.argmax(logits, dim=-1)

# decode the audio to generate text (CAN I USE ROBERTA HERE?)
transcriptions = tokenizer.decode(predicted_ids[0])

# print result
print(transcriptions)

What have you tried?

I've only tried the CTC decoder through Hugging Face Transformers in a Python notebook on my local machine.

What's your environment?

AndrewMcDowell commented 3 years ago

The line:

predicted_ids = torch.argmax(logits, dim=-1)

is essentially applying a greedy algorithm to the outputs of the wav2vec model, always taking the id to which wav2vec has assigned the highest probability.

If you want to feed the outputs to your own language model, you instead feed the logits as inputs to a second model and then apply the decoder to the outputs of that model.

If you added some layers to your model to take that as input, you could do something like:

predicted_ids = your_language_model(logits)

Otherwise, you could try converting the predicted ids into text and then feeding that text to RoBERTa's tokenizer directly.
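One common way to use an external LM without retraining anything is n-best rescoring: decode several candidate transcriptions, score each with the LM, and combine acoustic and LM scores (shallow fusion). A toy sketch, where the hypotheses and LM scores are invented; in practice the LM score would come from something like RoBERTa (e.g. pseudo-log-likelihood) or a KenLM model:

```python
# Toy n-best rescoring: pick the hypothesis maximizing
#   acoustic_log_prob + alpha * lm_log_prob.
# The candidate texts, scores, and the fake LM below are made up.

def rescore(nbest, lm_log_prob, alpha=0.5):
    # nbest: list of (text, acoustic_log_prob) pairs
    scored = [(text, ac + alpha * lm_log_prob(text)) for text, ac in nbest]
    return max(scored, key=lambda item: item[1])[0]

# Hypothetical LM: strongly prefers the grammatical sentence.
fake_lm = {"ik zie de hond": -5.0, "ik zie te hond": -12.0}

nbest = [("ik zie te hond", -3.1),   # acoustically slightly better
         ("ik zie de hond", -3.4)]

print(rescore(nbest, lambda t: fake_lm[t], alpha=0.5))  # -> "ik zie de hond"
```

The weight `alpha` trades off trusting the acoustic model against trusting the LM, and is usually tuned on a held-out set.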

Edit:

This page gives the clearest description I've seen of how language models were used in training:

https://github.com/flashlight/flashlight/tree/master/flashlight/app/asr#beam-search-decoder
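The decoder described on that page integrates the LM during the search rather than after it: at each step, hypotheses are scored by acoustic log-probability plus a weighted LM log-probability, and only the best few are kept. A heavily simplified sketch of that idea, with invented emissions and a toy character-bigram LM (a real CTC beam search would also handle blanks and merge hypotheses that collapse to the same string):

```python
# Simplified beam search with LM shallow fusion, in the spirit of the
# flashlight ASR decoder. Hypotheses are ranked by
#   acoustic_log_prob + lm_weight * lm_log_prob.
# The emissions and bigram LM below are toy examples.

def beam_search(emissions, lm_log_prob, lm_weight=0.5, beam_size=2):
    # emissions: one dict {label: log_prob} per frame (stand-in for logits)
    beams = [("", 0.0)]  # (label sequence, acoustic log-prob)
    for frame in emissions:
        candidates = [(seq + label, ac + lp)
                      for seq, ac in beams
                      for label, lp in frame.items()]
        # rank by fused score and keep only the top beam_size hypotheses
        candidates.sort(key=lambda c: c[1] + lm_weight * lm_log_prob(c[0]),
                        reverse=True)
        beams = candidates[:beam_size]
    return beams[0][0]

# Toy character-bigram LM: known pairs are likely, unknown pairs penalized.
bigram = {("h", "o"): -0.1, ("h", "a"): -3.0}

def lm_log_prob(seq):
    return sum(bigram.get(pair, -1.0) for pair in zip(seq, seq[1:]))

# Acoustically "ha" beats "ho" by a little, but the LM flips the ranking.
emissions = [{"h": -0.1}, {"o": -0.9, "a": -0.7}]
print(beam_search(emissions, lm_log_prob, lm_weight=0.5))  # -> "ho"
```

This is what `--w2l-decoder fairseqlm` (or `kenlm`) gives you in fairseq: the LM steers the search toward likelier word sequences instead of only rescoring finished hypotheses.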

stale[bot] commented 3 years ago

This issue has been automatically marked as stale. If this issue is still affecting you, please leave any comment (for example, "bump"), and we'll keep it open. We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment!