NVIDIA / NeMo

A scalable generative AI framework built for researchers and developers working on Large Language Models, Multimodal, and Speech AI (Automatic Speech Recognition and Text-to-Speech)
https://docs.nvidia.com/nemo-framework/user-guide/latest/overview.html
Apache License 2.0

beam search decoder and N-gram language model re-scoring with Citrinet #2142

Closed EmreOzkose closed 3 years ago

EmreOzkose commented 3 years ago

Hi, I am using Citrinet models to transcribe an audio file (taken from Mini-LibriSpeech), and the plain transcription is good. However, when I use beam search decoding as in this tutorial, I get irrelevant results. What could the problem be? Has anyone else faced this issue?

Target sentence: THE MANAGER TOOK MY ADVICE BUT MARIA'S STILL MISSING

import nemo.collections.asr as nemo_asr  # import used throughout the snippets below

asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name='stt_en_citrinet_1024', strict=False)
transcript = asr_model.transcribe(paths2audio_files=[wav_path])[0]
print(f'Transcript: "{transcript}"')

the manager took my advice but maria still missing

asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name='stt_en_citrinet_1024', strict=False)
lm_path = 'lowercase_3-gram.pruned.1e-7.arpa'
beam_search_lm = nemo_asr.modules.BeamSearchDecoderWithLM(
    vocab=list(asr_model.decoder.vocabulary),
    beam_width=16,
    alpha=2, beta=1.5,
    lm_path=lm_path,
    num_cpus=max(os.cpu_count(), 1),
    input_tensor=False)

result_lm = beam_search_lm.forward(log_probs=np.expand_dims(probs, axis=0), log_probs_length=None)
print(result_lm)

[[(-163.07594299316406, 'a'), (-164.35794067382812, 'an'), (-171.96630859375, 'aa'), (-173.6032257080078, 'aan'), (-2158.844482421875, 'aavn'), (-2158.845458984375, 'aai'), (-2158.8759765625, 'aaan'), (-2159.080078125, 'aabn'), (-2159.123779296875, 'aani'), (-2159.14208984375, 'avn'), (-2159.169677734375, 'aav'), (-2159.1845703125, 'aaa'), (-2159.33349609375, 'aab'), (-2159.3623046875, 'abn'), (-2159.367431640625, 'aain'), (-2159.513427734375, 'aavni')]]

In addition, when I use QuartzNet15x5Base, I get good results.

asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name='QuartzNet15x5Base-En', strict=False)
lm_path = 'lowercase_3-gram.pruned.1e-7.arpa'
beam_search_lm = nemo_asr.modules.BeamSearchDecoderWithLM(
    vocab=list(asr_model.decoder.vocabulary),
    beam_width=16,
    alpha=2, beta=1.5,
    lm_path=lm_path,
    num_cpus=max(os.cpu_count(), 1),
    input_tensor=False)

result_lm = beam_search_lm.forward(log_probs=np.expand_dims(probs, axis=0), log_probs_length=None)
print(result_lm)

[[(-37.2982063293457, 'the manager took my advice but maria still missing'), (-43.66973114013672, 'the manager took my advice but marias still missing'), (-44.21354675292969, 'the manager took my advice but maria still missing '), (-45.12510299682617, "the manager took my advice but maria's still missing"), (-46.51353454589844, 'the manager took my advice but mara still missing'), (-48.59613800048828, 'the manager took my advice but maria still missin'), (-50.58517837524414, 'the manager took my advice but marias still missing '), (-52.988243103027344, 'the manager took my advice but maria till missing'), (-54.96765899658203, 'the manager took my advice but marias still missin'), (-55.6533317565918, 'the manager took my advice but maria still messing'), (-56.42308044433594, "the manager took my advice but maria's still missin"), (-57.81150817871094, 'the manager took my advice but mara still missin'), (-62.115047454833984, 'the manager took my advice but maria till missin'), (-2045.078857421875, 'the manager took my advice but maria still missina'), (-2047.1024169921875, 'the manager took my advice but maria still misin'), (-2048.393798828125, 'the manager took my advice but maria still missen')]]

kruthikakr commented 3 years ago

@EmreOzkose You have swapped in a Citrinet model in place of QuartzNet, so you should also change the class name: use `nemo_asr.models.EncDecCTCModelBPE` in place of `nemo_asr.models.EncDecCTCModel`.

EmreOzkose commented 3 years ago

I changed the class name, but it didn't work.

titu1994 commented 3 years ago

The KenLM model being used is meant for character encoding models - that's why it is emitting random outputs for Citrinet (which is a subword encoding model). Please follow the instructions here to build a custom KenLM language model for subword encoding models - https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/asr_language_modeling.html

EmreOzkose commented 3 years ago

I am checking it out, thank you.

EmreOzkose commented 3 years ago

I trained a language model with train_kenlm.py (6-gram, on 5M lines of text). The KenLM training log:

=== 1/5 Counting and sorting n-grams ===
Reading /path/to/scripts/trained_models/5M-6-gram.tmp.txt
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************
Unigram tokens 157070288 types 939
=== 2/5 Calculating and sorting adjusted counts ===
Chain sizes: 1:11268 2:1648550784 3:3091032832 4:4945652224 5:7212409856 6:9891304448
Substituting fallback discounts for order 0: D1=0.5 D2=1 D3+=1.5
Statistics:
1 939 D1=0.5 D2=1 D3+=1.5
2 471539 D1=0.443965 D2=1.02436 D3+=1.62906
3 10227983 D1=0.647774 D2=1.09039 D3+=1.4926
4 38579524 D1=0.775996 D2=1.14548 D3+=1.43008
5 70229201 D1=0.861431 D2=1.18701 D3+=1.38699
6 93502154 D1=0.862239 D2=1.22359 D3+=1.39601
Memory estimate for binary LM:
type      MB
probing 4340 assuming -p 1.5
probing 5024 assuming -r models -p 1.5
trie    1880 without quantization
trie     954 assuming -q 8 -b 8 quantization 
trie    1605 assuming -a 22 array pointer compression
trie     679 assuming -a 22 -q 8 -b 8 array pointer compression and quantization
=== 3/5 Calculating and sorting initial probabilities ===
Chain sizes: 1:11268 2:7544624 3:204559660 4:925908576 5:1966417628 6:2992068928
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
####################################################################################################
=== 4/5 Calculating and writing order-interpolated probabilities ===
Chain sizes: 1:11268 2:7544624 3:204559660 4:925908576 5:1966417628 6:2992068928
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
####################################################################################################
=== 5/5 Writing ARPA model ===
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************
Name:lmplz  VmPeak:26349464 kB  VmRSS:13084 kB  RSSMax:7157716 kB   user:187.9  sys:30.7643 CPU:218.664 real:432.785
[NeMo I 2021-05-03 10:52:48 train_kenlm:137] Running binary_build command 

    /path/to/kenlm/build/bin/lmplz -o 6 --text /path/to/scripts/trained_models/5M-6-gram.tmp.txt --arpa /path/to/scripts/trained_models/5M-6-gram.tmp.arpa --discount_fallback

Reading /path/to/scripts/trained_models/5M-6-gram.tmp.arpa
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************
Identifying n-grams omitted by SRI
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************
Writing trie
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************
SUCCESS

I didn't delete the ARPA file on purpose, so that I could investigate it; hence there is no deletion logging at the end.

The language model trained without any errors. However, there is still a problem during inference. Citrinet outputs:

asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(model_name="stt_en_citrinet_256")
lm_path = "path/to/scripts/trained_models/5M-6-gram"

beam_search_lm = nemo_asr.modules.BeamSearchDecoderWithLM(
    vocab=list(asr_model.decoder.vocabulary),
    beam_width=16,
    alpha=1, beta=1.5,
    lm_path=lm_path,
    num_cpus=max(os.cpu_count(), 1),
    input_tensor=False)

wav_path = "/path/to/118-47824-0015.wav"
transcript = asr_model.transcribe(paths2audio_files=[wav_path])[0]
print(f'Transcript: "{transcript}"')
# output: the manager took my advice but maria still missing

logits = asr_model.transcribe([wav_path], logprobs=True)[0].cpu().numpy()
probs = softmax(logits)

beam_search_lm.forward(log_probs=np.expand_dims(probs, axis=0), log_probs_length=None)

gives:

[[(-174.5420379638672, 'i'), (-175.16610717773438, ''), (-175.2644805908203, 'fin'), (-175.37228393554688, 'f'), (-175.4651641845703, 'o'), (-175.60049438476562, 'w'), (-175.7156524658203, 'in'), (-175.9957733154297, 'fi'), (-176.3941650390625, 'wo'), (-176.45680236816406, 'fis'), (-176.63424682617188, 't'), (-176.68212890625, 'win'), (-176.70204162597656, 'iin'), (-176.73489379882812, 'wi'), (-177.02439880371094, 'io'), (-177.24179077148438, 'we')]]

In addition, QuartzNet15x5Base gives that:

[[(-471.08172607421875, 'themngertookymynvisesesutinmeristillmissing'), (-471.3054504394531, 'themngertookymynvisesesutimeristillmissing'), (-471.72296142578125, 'themngertookymyevisesesutinmeristillmissing'), (-471.86053466796875, 'themngertookymynvisesesuttmeristillmissing'), (-471.90106201171875, 'themngertookymynvisesesuttinmeristillmissing'), (-471.9837646484375, 'themngertookymynvisesesutinmristillmissing'), (-472.0075378417969, 'themngertookymynvisesesutimristillmissing'), (-474.77239990234375, 'themngertookymynvisesesutinmeristillmissin'), (-474.9961242675781, 'themngertookymynvisesesutimeristillmissin'), (-475.41363525390625, 'themngertookymyevisesesutinmeristillmissin'), (-475.55120849609375, 'themngertookymynvisesesuttmeristillmissin'), (-475.59173583984375, 'themngertookymynvisesesuttinmeristillmissin'), (-475.6744384765625, 'themngertookymynvisesesutinmristillmissin'), (-475.6982116699219, 'themngertookymynvisesesutimristillmissin'), (-477.19158935546875, 'themngertookymynvisesesutumeristillmissing'), (-477.216064453125, 'themngertookymyevisesesutimeristillmissing')]]
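As an aside, the `softmax` called on the logits in the snippet above is not defined there; a minimal, numerically stable version over the last (vocabulary) axis, assuming logits of shape `(time, vocab)` as produced by `transcribe(..., logprobs=True)`, could look like:

```python
import numpy as np

def softmax(logits):
    # Subtract the per-frame maximum for numerical stability,
    # then normalize over the last (vocabulary) axis.
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)
```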

EmreOzkose commented 3 years ago

I tried to evaluate with eval_beamsearch_ngram.py and it works: it decreases the word error rate. However, I don't understand why my first attempt doesn't work.

titu1994 commented 3 years ago

If you take a look at the script - we perform unicode ID encoding in the KenLM model (encoding subword IDs into unicode codepoints shifted by an offset of 100), therefore you would need to encode the emitted transcription of BPE models in the same way, then decode the shifted values again. That's about the major difference I can see.
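The mapping described above can be sketched in plain Python. The offset of 100 comes from the comment above, but the helper names are hypothetical, not NeMo's actual functions:

```python
TOKEN_OFFSET = 100  # shift constant mentioned above; assumed here

def encode_token_ids(token_ids):
    # Map each subword ID to a single unicode character shifted by the
    # offset - the form the subword KenLM training text is written in.
    return ''.join(chr(token_id + TOKEN_OFFSET) for token_id in token_ids)

def decode_token_ids(encoded_text):
    # Reverse the shift to recover the original subword IDs.
    return [ord(ch) - TOKEN_OFFSET for ch in encoded_text]
```

For example, `decode_token_ids(encode_token_ids([17, 4, 250]))` round-trips back to `[17, 4, 250]`.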

titu1994 commented 3 years ago

Also, you cannot use a KenLM trained on subwords for character-based model decoding, which is why you observe random characters in the QuartzNet output.

EmreOzkose commented 3 years ago

When I open my 6-gram.tmp.txt file, I see this:

¿ j g
Ȭ ɡ ʑ ‘ l ̮ ǩ h ȭ ̭
w ΋ l ‘ ý ë ą n õ å ® ¹ Ď
f £ s  s ‡ s ¼ e g ŏ ð e w Ɛ ̓ ǂ Ʌ „ ª o w Ö † ł ª q ķ υ k ˩ k ķ υ k ɓ
Ï ϒ ė Ŝ ‚ ƿ v Ù † Ļ Ģ  m ˘ l ü nj h ³ ž p Ʉ o Έ Ü Û e k Ă l ű ǜ â ē
² ‘ ý é ÷ ˄ ß e ǩ ¡ ­ Ǔ ț | q É υ ˩ ̶ ʁ ƨ ȸ
¬ ɡ t ĺ Ú ȇ Ę e § w Ç Ĺ e g ˩ ƨ ȸ ¤ x Œ ƒ ¢ è Ç ñ ´ f ͱ
‰ ½ í p ɓ υ ȸ ̶ ™ – ž e • ë Ɵ
 ơ ē l ʘ • k h „ ˄ Ȼ ć ë ʼ e œ h ĭ g ç ǖ Í Ŋ Ô š o · h î q Þ ̖ g m ¡ p Ð þ • Š ū p ě ǿ Ó Š f ¡ e ‡ ş à j e Ķ „ | g Ǖ ͔
f q Î Ė g  ̷ o € ƨ   ƪ
...

These are encoded IDs, right? Because there should be some subwords here, such as -est, -er, -long, etc. Did you encode these subwords (est, er, etc.) into the characters above? If these are encoded subwords, how does the model decode them back into target words?

I can also provide my ARPA file:

\1-grams:
-5.371167   <unk>   0
0   <s> -2.9420516
-2.602343   </s>    0
-2.8196058  ¿   -1.1019585
-2.8083456  j   -0.5263033
-2.988827   g   -1.2336508
-2.9303837  Ȭ   -0.82106936
-2.8464954  ɡ   -0.7884831
-2.9608734  ʑ   -0.8513431
-2.7415118      -1.2384112
-2.6194754  l   -1.8475825
-3.003509       -0.70200557
-2.8075523  ǩ   -0.79507536
-2.621536   h   -1.8086596
-2.9552693  ȭ   -0.77616245
-2.9421077      -0.8058051
-2.639997   w   -1.6862018
-2.9620028  ΋   -0.6901992
-3.1846087  ý   -0.8091599
-2.762429   ë   -1.3114026
...

As I said, I was expecting subwords instead of characters like ¿, ǩ, etc.

titu1994 commented 3 years ago

Yes, precisely - we take the multi-character subword token and map it to a single unicode ID (shifted by some constant). That is why the text is not displayed correctly: the tokens are encoded. You can read it line by line, shift the IDs back, and use each ID in a reverse lookup of the tokenizer to get back the true subword text.
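As a toy illustration of that reverse lookup (the ID-to-token table and the offset of 100 here are made-up stand-ins, not NeMo's actual tokenizer):

```python
TOKEN_OFFSET = 100  # assumed shift constant
toy_id_to_token = {0: "_the", 1: "_man", 2: "ager"}  # hypothetical id -> subword table

def recover_subwords(encoded_line):
    # Shift each character back to a subword ID, then look the ID up
    # in the tokenizer's id -> token table.
    ids = [ord(ch) - TOKEN_OFFSET for ch in encoded_line]
    return [toy_id_to_token[i] for i in ids]

# Build an encoded line for IDs [0, 1, 2] and recover the subwords.
line = ''.join(chr(i + TOKEN_OFFSET) for i in [0, 1, 2])
print(recover_subwords(line))  # ['_the', '_man', 'ager']
```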

titu1994 commented 3 years ago

The decoding step of the provided code snippet explains all of this, and the output of the model is just the direct subword ID without any such encoding, so you can do a reverse lookup with the tokenizer to get back the token.

We do all this so as to reduce the size of the LM binary.

EmreOzkose commented 3 years ago

Thank you so much! It is very reasonable. I understand except one issue. How do you convert multi-unicode subword to singular unicode ID?

titu1994 commented 3 years ago

Refer to the eval script provided here - https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/asr_language_modeling.html

Dreahim commented 9 months ago

Hello, I have the same problem. Were you able to solve it, and how? You and @titu1994 discuss decoding the unicode IDs above, but I didn't get the idea. Could you tell me how to solve it?