flashlight / text

Text utilities, including beam search decoding, tokenizing, and more, built for use in Flashlight.

Migration from OpenSeq2Seq decoder to Flashlight #68

Open Squire-tomsk opened 1 year ago

Squire-tomsk commented 1 year ago

Hello,

I am trying to migrate my ASR model from the OpenSeq2Seq decoder to Flashlight. Currently, I am using NeMo Conformer Large as the acoustic model, fine-tuned on my data. I also use KenLM as the language model, trained with this script provided by NeMo. However, the OpenSeq2Seq decoder does not support BPE tokens, so the subwords are mapped to single characters and KenLM is trained on these characters.
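
For reference, the subword-to-char mapping works roughly like this (a simplified sketch of the preprocessing; the offset and names here are illustrative, the real table is produced by the NeMo script):

import sentencepiece as spm

# Give every SentencePiece subword id a unique single character so that a
# "char-level" KenLM can model subword sequences
sp = spm.SentencePieceProcessor(model_file="tokenizer.model")
CHAR_OFFSET = 100  # arbitrary printable-range offset

subword_to_char = {i: chr(i + CHAR_OFFSET) for i in range(sp.vocab_size())}

def encode_line(text: str) -> str:
    # BPE-encode the text, then replace each subword id with its one-char
    # alias; KenLM is trained on the resulting character sequences
    ids = sp.encode(text, out_type=int)
    return "".join(subword_to_char[i] for i in ids)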

To decode log_probs from Conformer, I am using the following code:

import numpy as np
from flashlight.lib.text.decoder import CriterionType, KenLM, LexiconFreeDecoder, LexiconFreeDecoderOptions
from flashlight.lib.text.dictionary import Dictionary

tokenizer = load_sentencepiece_tokenizer("tokenizer.model")  # project-specific helper around sentencepiece
tokens_dict = Dictionary("encoded_decoder_vocabulary.txt")  # vocab of encoded subwords as chars, 1024 entries; blank token not included

sil_idx = 262     # index of the ▁ (word-boundary) token
blank_idx = 1024  # blank is appended after the 1024 vocab entries
lm = KenLM("kenlm.bin", tokens_dict)
transitions = np.zeros(tokens_dict.index_size() * tokens_dict.index_size())

options = LexiconFreeDecoderOptions(
    beam_size=128,          # max number of hypotheses kept at each step
    beam_size_token=1024,   # number of tokens considered when expanding a hypothesis
    beam_threshold=500.0,   # prune hypotheses whose score falls too far below the best
    lm_weight=1.0,
    sil_score=0.0,
    log_add=False,
    criterion_type=CriterionType.CTC,
)

_decoder = LexiconFreeDecoder(options=options, 
                              lm=lm, 
                              sil_token_idx=sil_idx, 
                              blank_token_idx=blank_idx, 
                              transitions=transitions)

def decoder(log_probs, log_probs_len):
    # log_probs shape: (batch, time, dictionary), dictionary = 1025 (1024 tokens + blank)
    # log_probs_len shape: (batch,), number of valid frames per utterance
    decode_result = []
    for i in range(log_probs.shape[0]):
        emissions = np.ascontiguousarray(log_probs[i].numpy())
        decode_beam = _decoder.decode(emissions.ctypes.data, int(log_probs_len[i]), log_probs.shape[-1])
        # keep the best hypothesis and drop the blank index
        tokens = [t for t in decode_beam[0].tokens if t < 1024]
        decode_result.append(tokens)
    return decode_result
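
For completeness, here is roughly how I call it (a sketch with dummy inputs just to show the expected shapes; in my pipeline log_probs comes from the Conformer and is already log-softmaxed):

import torch

dummy_log_probs = torch.randn(2, 50, 1025).log_softmax(dim=-1)  # (batch, time, tokens + blank)
dummy_lens = torch.tensor([50, 42])  # valid frames per utterance

hyps = decoder(dummy_log_probs, dummy_lens)
for token_ids in hyps:
    # token_ids index into tokens_dict (the encoded-char vocab); mapping them
    # back to text goes through the inverse char-to-subword mapping
    print(token_ids)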

However, the best WER I can get with this code is 11.02, while the OpenSeq2Seq decoder gives me 5.23. Can you please help me identify what I am doing wrong?

Also, is it okay to use a model trained without a sil token and set its score to zero? Moreover, according to the documentation found here, the lexicon-free decoder should use the defined --wordseparator or run with --usewordpiece=true. However, I couldn't find such parameters in the Python bindings. Should I set this parameter in my case, and is there an equivalent in the Python bindings?
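
To make the second question concrete: my current understanding is that the word separator splits a hypothesis into words at the ▁ index, roughly like this (my own sketch, which may not match what --wordseparator actually does):

def split_into_words(token_ids, sep_idx=262):
    # group decoded token ids into words at the ▁ / sil index
    words, current = [], []
    for t in token_ids:
        if t == sep_idx:
            if current:
                words.append(current)
            current = []
        else:
            current.append(t)
    if current:
        words.append(current)
    return words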

Thank you.

Squire-tomsk commented 1 year ago

This question is still relevant :) @jacobkahn, maybe you can help me with this migration?