kensho-technologies / pyctcdecode

A fast and lightweight python-based CTC beam search decoder for speech recognition.
Apache License 2.0

Using Nemo with BPE models #9

Closed: pehonnet closed this issue 3 years ago

pehonnet commented 3 years ago

Hello,

Great repo! The tutorial for NeMo models works fine, but when switching to a BPE model (like the recent Conformer model available in NeMo), NeMo seems to apply a trick that changes the alphabet, and pyctcdecode does not:

https://github.com/NVIDIA/NeMo/blob/acbd88257f20e776c09f5015b8a793e1bcfa584d/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram.py#L112

When I try to run something similar to the NeMo notebook, all the tokens seem shifted, which is why I suspect it is related to this token offset.
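For readers unfamiliar with the trick: as far as we can tell from the linked script, for BPE models NeMo trains KenLM on text in which every subword token id is replaced by a single Unicode character, `chr(id + offset)`, and builds its decoder alphabet the same way. A decoder whose labels are the plain subword strings would then not line up with the LM's symbols. This is a minimal sketch of that mapping only; the offset value of 100 and all function names here are assumptions for illustration, not NeMo's API.

```python
from typing import List

# Hypothetical sketch of the "token offset" idea: every subword id becomes one
# Unicode character, so KenLM sees a single symbol per BPE token.
TOKEN_OFFSET = 100  # assumption: check the linked NeMo script for the real value


def ids_to_offset_chars(token_ids: List[int], offset: int = TOKEN_OFFSET) -> str:
    """Encode subword ids as characters: id -> chr(id + offset)."""
    return "".join(chr(i + offset) for i in token_ids)


def offset_chars_to_ids(text: str, offset: int = TOKEN_OFFSET) -> List[int]:
    """Inverse mapping: character -> ord(ch) - offset."""
    return [ord(ch) - offset for ch in text]


def offset_alphabet(vocab_size: int, offset: int = TOKEN_OFFSET) -> List[str]:
    """Decoder labels that would match an LM trained on offset characters."""
    return [chr(i + offset) for i in range(vocab_size)]


print(ids_to_offset_chars([12, 23]))  # p{
print(offset_alphabet(3))             # ['d', 'e', 'f']
```

The point of the sketch is that the LM and the decoder must agree on the same symbol set; which convention is used matters less than using the same one on both sides.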

Thanks

poneill commented 3 years ago

Hi, thanks for this. Could you post an MWE (minimal working example)?

pehonnet commented 3 years ago

Sure, here is how to reproduce similar behavior:

  1. Create a KenLM LM with the NeMo script (which applies the token offset). Here we use Fisher text data, but any other text corpus would work. The script used is https://github.com/NVIDIA/NeMo/blob/main/scripts/asr_language_modeling/ngram_lm/train_kenlm.py, run with the following command line:

    python scripts/asr_language_modeling/ngram_lm/train_kenlm.py \
    --nemo_model_file examples/asr/downloaded_models/stt_en_conformer_ctc_large.nemo \
    --train_file data/fisher/training_text_normalized_lowercase.txt \
    --kenlm_bin_path decoders/kenlm/buildbin/bin \
    --kenlm_model_file lm_kenlm/fisher_lm_4gram_conformer.bin \
    --ngram_length 4
  2. In a notebook, run pyctcdecode as in the tutorial:

    
    ############################
    import nemo.collections.asr as nemo_asr
    from nemo.collections.asr.metrics.wer import word_error_rate
    import kenlm
    import gzip
    import os, shutil, wget
    from pyctcdecode import build_ctcdecoder
    import multiprocessing
    import pandas as pd
    import numpy as np
    import kenlm_utils
    import torch
    # get a single audio file
    !wget https://dldata-public.s3.us-east-2.amazonaws.com/1919-142785-0028.wav
    ############################

    ############################
    asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(model_name='stt_en_conformer_ctc_large')
    ############################

    ############################
    # transcribe audio to logits
    logits = asr_model.transcribe(["1919-142785-0028.wav"], logprobs=True)[0]
    probs = kenlm_utils.softmax(logits)
    preds = np.argmax(probs, axis=1)
    preds_tensor = torch.tensor(preds, device='cpu').unsqueeze(0)
    pred_text = asr_model._wer.ctc_decoder_predictions_tensor(preds_tensor)[0]
    print(pred_text)
    ############################
    # OUT: boil them before they are put into the soup or other dish they may be intended for

    ############################
    # look at the alphabet of our model, which defines the labels for the logit matrix we just calculated
    asr_model.decoder.vocabulary
    ############################
    # OUT: ['', '▁', '##s', '##t', '##e', '##a', '##i', '##d', '##o', '##m', '##l', 'a', 'the', '##n', '##p', '##y', '##u', '##er', '##h', '##re', '##c', '##r', 'i', 's', '##ing', "##'", 'to', '##k', '##g', 'and', '##f', '##an', '##ed', '##w', 'of', 'w', 'in', 'you', '##or', '##v', 'f', 'b', '##in', '##b', '##ar', '##en', '##ve', '##es', 'it', 'that', '##al', '##ll', 'be', 'co', 'he', '##le', '##ch', 'do', 'we', '##ly', '##it', 'g', '##ent', '##ur', '##on', '##ic', 'c', 'on', '##un', '##th', 'ha', 'was', 't', 'is', 'for', 'ma', '##ri', 'know', 're', '##ra', 'th', '##us', '##ce', 'but', 'go', '##ro', 'mo', 'st', 'me', 'they', 'so', 'yeah', 'ca', '##ir', 'have', 'like', 'di', 'ho', '##ion', 'with', 'lo', 'not', 'no', 'de', 'ne', 'pa', '##ation', 'this', 'what', 'some', 'uh', 'pro', 'my', '##j', 'his', 'ex', 'just', 'um', 'po', 'la', 'think', 'she', '##ough', 'from', '##q', 'out', '##x', '##z']

    ############################
    # build the decoder and decode the logits
    kenlm_model = kenlm.Model('fisher_lm_4gram_conformer.bin')
    labels = asr_model.decoder.vocabulary
    decoder = build_ctcdecoder(labels, kenlm_model, is_bpe=True)
    decoded = decoder.decode(logits)
    print(decoded)
    ############################
    # OUT: z fzdzazmz azozllz andzdhz mezlhz nysz wzsdz theyzynz z youz zdunzuz likez erz mez forzpzllz wzszarzizanz isz
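For comparison with the beam-search output, the greedy baseline in step 2 (per-frame argmax followed by `ctc_decoder_predictions_tensor`) is, in essence, standard greedy CTC decoding: take the best label per frame, merge consecutive repeats, then drop blanks. A minimal NumPy sketch, assuming the blank is the last logit column (our understanding of NeMo's CTC convention, one index past the end of `decoder.vocabulary`); `greedy_ctc_ids` is a hypothetical helper, not a NeMo or pyctcdecode function. Because softmax is monotonic, the argmax of the raw log-probabilities gives the same path as the argmax after `kenlm_utils.softmax`.

```python
from typing import List

import numpy as np


def greedy_ctc_ids(logits: np.ndarray, blank_id: int) -> List[int]:
    """Greedy CTC decoding over a (time, num_classes) score matrix.

    blank_id is the index of the CTC blank; assumed here to be the last column.
    """
    best = np.argmax(logits, axis=1)  # best class per frame
    ids: List[int] = []
    prev = None
    for t in best:
        t = int(t)
        # keep a label only when it differs from the previous frame and is not blank
        if t != prev and t != blank_id:
            ids.append(t)
        prev = t
    return ids


# toy example: 4 classes, blank assumed to be the last one (id 3)
frames = np.eye(4)[[1, 1, 3, 0, 0, 3, 2]]
print(greedy_ctc_ids(frames, blank_id=3))  # [1, 0, 2]
```

In the MWE this would correspond to `greedy_ctc_ids(logits, blank_id=len(asr_model.decoder.vocabulary))`, again under the blank-last assumption; turning the ids back into text still depends on the tokenizer's own conventions.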

poneill commented 3 years ago

Thanks, this is very helpful. I believe this is a known issue with our unk handling in BPE alphabets, addressed by https://github.com/kensho-technologies/pyctcdecode/pull/4. Will update.

gkucsko commented 3 years ago

Hi, yes, sorry, there are a lot of BPE conventions. We will update the vocabulary parsing soon in #4. In the meantime, the manual conversion below should work for you:

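    # keep the first two entries unchanged, strip the "##" continuation marker,
    # and mark every other (word-initial) token with the "▁" word-boundary prefix
    vocab = asr_model.decoder.vocabulary.copy()
    vocab = vocab[:2] + [c[2:] if c[:2] == "##" else "▁" + c for c in vocab[2:]]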
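To see what the conversion does, here is the same list comprehension wrapped in a hypothetical helper and applied to a toy vocabulary made of tokens from the list printed earlier, together with a naive join that turns converted labels back into words by treating "▁" as a word boundary. This is an illustration of the label convention only; it does not reproduce pyctcdecode's internal BPE handling.

```python
from typing import List


def nemo_wordpiece_to_sentencepiece(vocab: List[str]) -> List[str]:
    """Same rule as the one-liner: keep the first two entries as-is, strip the
    '##' continuation prefix, and prefix word-initial tokens with '▁'."""
    return vocab[:2] + [c[2:] if c[:2] == "##" else "▁" + c for c in vocab[2:]]


def join_labels(labels: List[str]) -> str:
    """Naive detokenization: '▁' marks the start of a new word."""
    return "".join(labels).replace("▁", " ").strip()


sample = ["", "▁", "##s", "##t", "the", "a", "##m", "they"]
converted = nemo_wordpiece_to_sentencepiece(sample)
print(converted)  # ['', '▁', 's', 't', '▁the', '▁a', 'm', '▁they']
print(join_labels([converted[4], converted[6], converted[5], converted[3]]))  # them at
```

After the conversion, continuation pieces attach to the previous word and every other token starts a new one, which is the "▁"-prefixed convention the manual fix targets.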
pehonnet commented 3 years ago

That works fine, thanks a lot!