kensho-technologies / pyctcdecode

A fast and lightweight python-based CTC beam search decoder for speech recognition.
Apache License 2.0

Using Nemo with BPE models #9

Closed pehonnet closed 3 years ago

pehonnet commented 3 years ago


Great repo! The tutorial for NeMo models works fine, but when switching to a BPE model (like the recent Conformer one available in NeMo), it seems NeMo applies some trick that changes the alphabet, which pyctcdecode does not do.

When I try to run something similar to the NeMo notebook, all the tokens seem shifted, which is why I guess it's related to this token offset.


poneill commented 3 years ago

Hi, thanks for this. Could you post a MWE?

pehonnet commented 3 years ago

Sure, here is how to reproduce similar behavior:

  1. Create a KenLM language model with the NeMo script (which applies the token offset). Here we use Fisher text data, but it could be something else. The script is run with the following command line:

    python scripts/asr_language_modeling/ngram_lm/ \
    --nemo_model_file examples/asr/downloaded_models/stt_en_conformer_ctc_large.nemo \
    --train_file data/fisher/training_text_normalized_lowercase.txt \
    --kenlm_bin_path decoders/kenlm/buildbin/bin \
    --kenlm_model_file lm_kenlm/fisher_lm_4gram_conformer.bin \
    --ngram_length 4

  2. In a notebook, run the following with pyctcdecode, similarly to the tutorial:

    import nemo.collections.asr as nemo_asr
    from nemo.collections.asr.metrics.wer import word_error_rate
    import kenlm
    import gzip
    import os, shutil, wget
    from pyctcdecode import build_ctcdecoder
    import multiprocessing
    import pandas as pd
    import numpy as np
    import kenlm_utils
    import torch
    # get a single audio file

    asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(model_name='stt_en_conformer_ctc_large')

    # transcribe audio to logits
    logits = asr_model.transcribe(["1919-142785-0028.wav"], logprobs=True)[0]
    probs = kenlm_utils.softmax(logits)
    preds = np.argmax(probs, axis=1)
    preds_tensor = torch.tensor(preds, device='cpu').unsqueeze(0)
    pred_text = asr_model._wer.ctc_decoder_predictions_tensor(preds_tensor)[0]
    print(pred_text)
    # OUT: boil them before they are put into the soup or other dish they may be intended for

    # look at the alphabet of our model defining the labels for the logit matrix we just calculated
    asr_model.decoder.vocabulary
    # OUT: ['', '▁', '##s', '##t', '##e', '##a', '##i', '##d', '##o', '##m', '##l', 'a', 'the', '##n', '##p', '##y', '##u', '##er', '##h', '##re', '##c', '##r', 'i', 's', '##ing', "##'", 'to', '##k', '##g', 'and', '##f', '##an', '##ed', '##w', 'of', 'w', 'in', 'you', '##or', '##v', 'f', 'b', '##in', '##b', '##ar', '##en', '##ve', '##es', 'it', 'that', '##al', '##ll', 'be', 'co', 'he', '##le', '##ch', 'do', 'we', '##ly', '##it', 'g', '##ent', '##ur', '##on', '##ic', 'c', 'on', '##un', '##th', 'ha', 'was', 't', 'is', 'for', 'ma', '##ri', 'know', 're', '##ra', 'th', '##us', '##ce', 'but', 'go', '##ro', 'mo', 'st', 'me', 'they', 'so', 'yeah', 'ca', '##ir', 'have', 'like', 'di', 'ho', '##ion', 'with', 'lo', 'not', 'no', 'de', 'ne', 'pa', '##ation', 'this', 'what', 'some', 'uh', 'pro', 'my', '##j', 'his', 'ex', 'just', 'um', 'po', 'la', 'think', 'she', '##ough', 'from', '##q', 'out', '##x', '##z']

    # build the decoder and decode the logits
    kenlm_model = kenlm.Model('fisher_lm_4gram_conformer.bin')
    labels = asr_model.decoder.vocabulary

    decoder = build_ctcdecoder(asr_model.decoder.vocabulary, kenlm_model, is_bpe=True)

    decoded = decoder.decode(logits)
    print(decoded)
    # OUT: z fzdzazmz azozllz andzdhz mezlhz nysz wzsdz theyzynz z youz zdunzuz likez erz mez forzpzllz wzszarzizanz isz
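
For reference, the greedy baseline in step 2 (argmax per frame, then `ctc_decoder_predictions_tensor`) can be approximated by the sketch below. It is not NeMo's implementation: the function names, the four-token vocabulary, and the frame ids are made up for illustration, and it assumes the CTC blank is the extra index after the last vocabulary entry.

```python
def greedy_ctc_decode(frame_ids, vocab, blank_id):
    """Collapse repeated frame predictions and drop blanks (greedy CTC)."""
    pieces = []
    prev = None
    for idx in frame_ids:
        # Emit a token only when it differs from the previous frame and is not blank.
        if idx != prev and idx != blank_id:
            pieces.append(vocab[idx])
        prev = idx
    return pieces


def merge_wordpieces(pieces):
    """Join pieces into text: '##xyz' continues the previous word, anything else starts a new one."""
    words = []
    for piece in pieces:
        if piece.startswith("##"):
            if words:
                words[-1] += piece[2:]
            else:
                words.append(piece[2:])
        else:
            words.append(piece)
    return " ".join(words)


toy_vocab = ["the", "co", "##s", "##t"]  # hypothetical 4-token vocabulary
blank_id = len(toy_vocab)                # assumption: blank is the extra last index
frame_ids = [0, 0, 4, 1, 1, 2, 4, 3, 3]  # per-frame argmax, made up for illustration

pieces = greedy_ctc_decode(frame_ids, toy_vocab, blank_id)
text = merge_wordpieces(pieces)
print(pieces)  # ['the', 'co', '##s', '##t']
print(text)    # the cost
```

The point of the sketch is that the text only comes out right when the decoder and the vocabulary agree on how word starts and continuations are marked; a decoder expecting a different convention will glue or split the same pieces differently.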

poneill commented 3 years ago

Thanks, this is very helpful. I believe this is a known issue with our unk handling in BPE alphabets, which is already being addressed. Will update.

gkucsko commented 3 years ago

Hi, yeah, sorry, there are a lot of BPE conventions. We will update the vocabulary parsing soon in #4. In the meantime, the manual conversion below should work for you:

    vocab = asr_model.decoder.vocabulary.copy()
    vocab = vocab[:2] + [c[2:] if c[:2] == "##" else "▁" + c for c in vocab[2:]]

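
To illustrate what this conversion does, here is a minimal sketch applied to a hand-picked subset of the vocabulary printed earlier in the thread. The `convert_vocab` helper name is hypothetical and not part of pyctcdecode; it only wraps the same list comprehension.

```python
def convert_vocab(vocab):
    """Rewrite '##'-style word-piece tokens into the '▁'-prefix convention."""
    # The first two entries are kept unchanged, as in the one-liner.
    head = list(vocab[:2])
    # '##xyz' marks a continuation piece and becomes 'xyz';
    # any other token starts a word and becomes '▁xyz'.
    tail = [tok[2:] if tok[:2] == "##" else "▁" + tok for tok in vocab[2:]]
    return head + tail


# Hand-picked subset of the vocabulary shown in the thread.
sample = ['', '▁', '##s', '##t', 'a', 'the', '##ing', "##'"]
converted = convert_vocab(sample)
print(converted)  # ['', '▁', 's', 't', '▁a', '▁the', 'ing', "'"]
```

The converted list keeps the same length and order as the original, so each entry still lines up with the same column of the logit matrix.
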
pehonnet commented 3 years ago

That works fine, thanks a lot!