facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

[Wav2Vec2] Wav2Vec2Conformer Fine-Tuned seems to give Gibberish on Librispeech example #4356

Open patrickvonplaten opened 2 years ago

patrickvonplaten commented 2 years ago

🐛 Bug

Wav2Vec2's newly released fine-tuned conformer checkpoints (see here) don't produce reasonable results on an example of Librispeech.

I'm not sure if the model requires a different

To Reproduce

  1. Download 960h fine-tuned checkpoint: wget https://dl.fbaipublicfiles.com/fairseq/conformer/wav2vec2/librilight/LL_relpos_PT_960h_FT.pt

  2. Download Librispeech Dict: wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt

  3. Load a sample of the Librispeech clean dataset for inference. You can load a dummy sample via the Hugging Face Hub

pip install datasets

from datasets import load_dataset
libri_dummy = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")

# check out the dataset
print(libri_dummy)
  4. Run a forward pass
import torch
import fairseq

input_sample = torch.tensor(libri_dummy[0]["audio"]["array"])[None, :]

# normalize
input_sample = torch.nn.functional.layer_norm(input_sample, input_sample.shape)

model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(['LL_relpos_PT_960h_FT.pt'], arg_overrides={"data": "/path/to/folder/of/dict"})
model = model[0]
model.eval()

logits = model(source=input_sample, padding_mask=None)["encoder_out"]
  5. Decode the prediction

The output is a tensor of shape [seq_len, 1, vocab_size]. We are interested in the most likely token for each time step. So we can take the argmax:

predicted_ids = torch.argmax(logits[:, 0], dim=-1)
  6. Now we'll create our own decoder based on the dict we downloaded previously to decode the result (it's just the dict put into JSON format)
json_dict = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3, "|": 4, "E": 5, "T": 6, "A": 7, "O": 8, "N": 9, "I": 10, "H": 11, "S": 12, "R": 13, "D": 14, "L": 15, "U": 16, "M": 17, "W": 18, "C": 19, "F": 20, "G": 21, "Y": 22, "P": 23, "B": 24, "V": 25, "K": 26, "'": 27, "X": 28, "J": 29, "Q": 30, "Z": 31}

and create a decoder

import numpy as np
from itertools import groupby

class Decoder:
    def __init__(self, json_dict):
        self.dict = json_dict
        self.look_up = np.asarray(list(self.dict.keys()))

    def decode(self, ids):
        converted_tokens = self.look_up[ids]
        fused_tokens = [tok[0] for tok in groupby(converted_tokens)]
        output = ' '.join(''.join(''.join(fused_tokens).split("<s>")).split("|"))
        return output
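The collapse-and-strip rule that this Decoder applies can be checked on a toy sequence first. The snippet below is only a sketch: it assumes, as the Decoder above does, that `<s>` (id 0) plays the role of the CTC blank, and the vocabulary slice and `toy_ids` are made up for illustration.

```python
from itertools import groupby

# First entries of the letter dictionary; id 0 ("<s>") is assumed to be the blank.
VOCAB = ["<s>", "<pad>", "</s>", "<unk>", "|", "E", "T", "A", "O", "N", "I", "H", "S"]
BLANK_ID = 0


def ctc_greedy_decode(ids, vocab=VOCAB, blank_id=BLANK_ID):
    # 1. Collapse runs of identical ids (one token usually spans several frames).
    collapsed = [key for key, _ in groupby(ids)]
    # 2. Drop blanks; a blank between two equal ids is what preserves a double letter.
    tokens = [vocab[i] for i in collapsed if i != blank_id]
    # 3. "|" marks word boundaries in the letter dictionary.
    return "".join(tokens).replace("|", " ").strip()


toy_ids = [6, 6, 0, 11, 11, 5, 0, 4, 4, 12, 0, 5, 5, 0, 5]
print(ctc_greedy_decode(toy_ids))  # THE SEE
```

If this rule and the id-to-letter mapping are both right, a correct model output should decode to readable text.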

Now we can decode the output and compare it to the correct output:

decoder = Decoder(json_dict=json_dict)
print("Prediction: ", decoder.decode(predicted_ids))

As we can see the prediction is wrong:

Prediction:  AY N N N VN V'V'V'IRSIMG KMNJB TPEPEMEDYR RGQTQ'OB 'HNJ<unk>TQURIEMJ' 'B'F' TM'VS'NEMDJH DSB'CNSTITE  RYKYRSITPSV'DYNY' M'SOEPUGSYDYH'BYTITIPKV UFMQ'W'YJRDH' MVY'SGM'GNE F'YZH'U IFB'N' ' A V YA IN

The correct transcription is:

print(libri_dummy[0]["text"])
'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'

Also from looking at the predicted ids of the model (the argmax logits):

tensor([ 7, 22,  4,  4,  9,  4,  4,  9,  9,  9,  4,  4,  9,  4,  4,  4, 25,  9,
         4,  4,  4,  4,  4, 25, 27, 25, 25, 25, 27, 25, 27, 10, 13,  0, 12, 10,
        17, 21, 21, 21,  0,  4, 26, 17, 17, 17,  9,  9,  9,  9, 29,  0, 24,  4,
         6, 23,  5, 23,  5, 17, 17,  5,  5,  5, 14, 22, 22, 13, 13,  4,  4, 13,
        21, 21, 21, 21, 30,  6, 30, 27,  8, 24,  0,  4, 27, 11,  9, 29, 29,  3,
         6, 30, 16, 13, 10, 10,  5, 17, 29, 29, 29, 27,  4,  4, 27, 24, 24, 27,
        20, 27, 27, 27,  4,  4,  6,  0, 17, 27, 25, 12, 12, 27,  9,  5, 17, 14,
        29, 11,  4, 14, 12, 12, 24, 27, 19, 19,  9, 12,  6, 10, 10, 10,  6,  6,
         5,  4,  4,  0,  4, 13, 22, 22, 22, 22, 26, 22, 13, 13, 12, 10,  0,  6,
        23, 12,  0, 25, 27, 14, 22,  9, 22, 27, 27,  4, 17, 27, 12,  8,  8,  5,
         5, 23, 16, 21, 12, 22, 14, 22, 11, 27, 24, 22,  6, 10,  6, 10, 10, 10,
        23, 26, 25, 25, 25,  0,  4, 16, 20, 17, 30, 27, 27, 18, 27, 22, 29, 13,
        14, 11, 11,  0, 27,  4,  4, 17, 17, 25, 22, 22, 27, 12, 21, 17, 27, 27,
        27, 21,  9,  5,  5,  5,  4,  0,  0, 20, 27, 22, 22, 31, 11, 11, 11, 27,
        16, 16,  4, 10,  0,  0,  0, 20, 24, 27, 27, 27, 27, 27, 27, 27, 27, 27,
        27,  9,  9,  9, 27,  4,  4,  4, 27,  4,  4,  4,  7,  4, 25,  4, 22,  7,
         4,  4, 10,  9])

It does seem like there is something wrong with the model and not just the dictionary: no single id dominates the output, as one would expect from an id that represents silence.
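A quick way to quantify that observation is to count how often each id occurs. This is a rough sketch, not part of fairseq; `most_common_id` is a hypothetical helper, and the toy ids are made up to look like a typical CTC output in which the blank fills most frames.

```python
from collections import Counter


def most_common_id(ids):
    """Return (id, share of frames) for the most frequent id in a sequence."""
    counts = Counter(ids)
    top_id, top_count = counts.most_common(1)[0]
    return top_id, top_count / len(ids)


# Made-up ids: blank (0) on 11 of 16 frames.
toy_ids = [0, 0, 0, 17, 0, 0, 10, 0, 12, 12, 0, 0, 6, 0, 0, 0]
print(most_common_id(toy_ids))  # (0, 0.6875)
```

For the real output, `most_common_id(predicted_ids.tolist())` gives the same statistic.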

Expected behavior

The model should work correctly here.

Environment

Additional context

patrickvonplaten commented 2 years ago

Gently pinging @sravyapopuri388 and @alexeib

sravyapopuri388 commented 2 years ago

Thanks for the ping @patrickvonplaten. I will look into this and get back to you.

patrickvonplaten commented 2 years ago

Hey @sravyapopuri388, any updates on this by any chance? :-)

sravyapopuri388 commented 2 years ago

Hi, I tried decoding the model using the following command from the wiki and the results are good. Could you please recheck your setup? Thanks!

SUBSET=dev_other
python3 examples/speech_recognition/infer.py $DATA_DIR --task audio_finetuning \
--nbest 1 --path $CKPT --gen-subset $SUBSET \
--results-path  $result_path --w2l-decoder viterbi \
--criterion ctc --labels $LABELS --max-tokens 0   \
--post-process letter --word-score -1 --sil-weight 0  --batch-size 1

patrickvonplaten commented 2 years ago

Hey @sravyapopuri388,

Sorry I don't have access to /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw or dev_other or kenlm.bin so it's not possible for me to run this command.

If possible, it would be great if you could post a command that shows how the model gives good results on a single audio file without a language model - this would be super helpful for the community to use these models.

Could you maybe check the above commands to see if CTC without a language model works correctly?

patrickvonplaten commented 2 years ago

Do you know which dictionary was used for the model?

sravyapopuri388 commented 2 years ago

Hi @patrickvonplaten, I updated the above command to not use a language model and it still works correctly. I used the dictionary open sourced in the wav2vec README here

To run with a single audio file, you can format it in the wav2vec data format and run the above command.
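As a rough sketch of what that could look like for one file: the layout below (a `.tsv` manifest whose first line is the audio root, plus `.wrd` and `.ltr` label files) is my understanding of what `wav2vec_manifest.py` and `libri_labels.py` in the fairseq examples produce, so please verify it against those scripts. The function name, paths, split name, and sample count are all hypothetical.

```python
import os
import tempfile


def write_single_file_split(out_dir, split, audio_root, rel_path, num_frames, transcript):
    """Write {split}.tsv, {split}.wrd and {split}.ltr for one utterance."""
    os.makedirs(out_dir, exist_ok=True)

    # Manifest: first line is the audio root, then "<relative path>\t<num samples>".
    with open(os.path.join(out_dir, f"{split}.tsv"), "w") as f:
        f.write(f"{audio_root}\n{rel_path}\t{num_frames}\n")

    # Word-level transcript, one utterance per line.
    with open(os.path.join(out_dir, f"{split}.wrd"), "w") as f:
        f.write(transcript + "\n")

    # Letter-level transcript: space-separated characters, "|" as word boundary.
    letters = " ".join(transcript.replace(" ", "|")) + " |"
    with open(os.path.join(out_dir, f"{split}.ltr"), "w") as f:
        f.write(letters + "\n")
    return letters


# Hypothetical values for illustration only.
with tempfile.TemporaryDirectory() as data_dir:
    letters = write_single_file_split(
        data_dir, "single", "/path/to/audio", "sample.flac", 16000, "HELLO WORLD"
    )
    print(letters)  # H E L L O | W O R L D |
    print(sorted(os.listdir(data_dir)))
```

The directory would then be passed as `$DATA_DIR` with `--gen-subset single`, presumably with `dict.ltr.txt` placed in the same folder.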

patrickvonplaten commented 2 years ago

Hey @sravyapopuri388, thanks for the pointers - I used the wrong dictionaries :sweat_smile: . Decoding now works as expected!
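For anyone hitting the same mismatch, one way to rule it out is to build the id mapping directly from the dictionary file instead of typing it by hand. The sketch below assumes the usual fairseq dictionary layout (one `symbol count` pair per line, with the four special symbols taking ids 0 to 3 before the file entries); `load_letter_dict` is a hypothetical helper and the counts in `sample_lines` are placeholders.

```python
def load_letter_dict(lines, specials=("<s>", "<pad>", "</s>", "<unk>")):
    """Map each symbol to its id, with the special symbols occupying the first ids."""
    token_to_id = {tok: i for i, tok in enumerate(specials)}
    for line in lines:
        line = line.rstrip("\n")
        if not line:
            continue
        # Each line is "<symbol> <count>"; only the symbol matters for the mapping.
        symbol = line.rsplit(" ", 1)[0]
        token_to_id[symbol] = len(token_to_id)
    return token_to_id


# Placeholder counts; only the order of the lines matters here.
sample_lines = ["| 100\n", "E 90\n", "T 80\n", "A 70\n"]
mapping = load_letter_dict(sample_lines)
print(mapping["|"], mapping["E"], mapping["A"])  # 4 5 7
```

With the real file, `load_letter_dict(open("dict.ltr.txt"))` can be compared against a hand-written mapping such as `json_dict` with a plain `==`.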

rahulshivajipawar commented 2 years ago

Hey @patrickvonplaten, can you post the command that worked for you? Thanks.

patrickvonplaten commented 2 years ago

The very first command actually worked correctly @rahulshivajipawar

There is also a HF implementation now: https://github.com/huggingface/transformers/pull/16812

BakingBrains commented 1 year ago

@patrickvonplaten Hello, is there any pretraining and fine-tuning notebook for Wav2Vec-Conformer using transformers?