k2-fsa / k2

FSA/FST algorithms, differentiable, with PyTorch compatibility.
https://k2-fsa.github.io/k2
Apache License 2.0

How to use large-scale language models in k2 #1010

Open wwx007121 opened 2 years ago

wwx007121 commented 2 years ago

I would like to separate the acoustic task from the text task more cleanly. My plan is an acoustic model whose encoder uses phonemes as the modeling unit, followed by a transformer or transducer decoder. But an exception occurs when HLG is built from the full 3-gram.arpa. Before this experiment, building HLG from 3-gram.pruned.3e-7.arpa and 3-gram.pruned.1e-7.arpa succeeded, and those graphs also worked in decoding. @danpovey

Log detail:

2022-07-12 17:33:26,070 INFO [compile_hlg.py:203] Processing data/lang_phone
2022-07-12 17:33:26,351 INFO [compile_hlg.py:80] Loading pre-compiled data/lang_phone/Linv.pt
2022-07-12 17:33:26,447 INFO [compile_hlg.py:122] Building ctc_topo. max_token_id: 71
2022-07-12 17:33:26,542 INFO [compile_hlg.py:131] Loading G_3_gram.fst.txt
<class '_k2.RaggedArc'>
2022-07-12 17:36:36,043 INFO [compile_hlg.py:142] Intersecting L and G
2022-07-12 17:49:59,343 INFO [compile_hlg.py:144] LG shape: (306009979, None)
2022-07-12 17:49:59,343 INFO [compile_hlg.py:146] Connecting LG
2022-07-12 17:49:59,343 INFO [compile_hlg.py:148] LG shape after k2.connect: (306009979, None)
2022-07-12 17:49:59,343 INFO [compile_hlg.py:150] <class 'torch.Tensor'>
2022-07-12 17:49:59,343 INFO [compile_hlg.py:151] Determinizing LG
2022-07-12 18:11:49,404 INFO [compile_hlg.py:154] <class '_k2.ragged.RaggedTensor'>
2022-07-12 18:11:49,405 INFO [compile_hlg.py:156] Connecting LG after k2.determinize
2022-07-12 18:11:49,405 INFO [compile_hlg.py:159] Removing disambiguation symbols on LG
[F] /k2/build/temp.linux-x86_64-3.8/k2/csrc/array_ops.cc:279:void k2::RowSplitsToRowIds(const k2::Array1&, k2::Array1) Check failed: num_elems == row_splits[num_rows] (1219275412 vs. 209651427)

Python traceback:

Traceback (most recent call last):
  File "./local/compile_hlg.py", line 217, in <module>
    main()
  File "./local/compile_hlg.py", line 205, in main
    HLG = compile_HLG(lang_dir)
  File "./local/compile_hlg.py", line 169, in compile_HLG
    LG = k2.remove_epsilon(LG)
  File "*/anaconda3/lib/python3.8/site-packages/k2-1.15.1.dev20220617+cpu.torch1.8.1-py3.8-linux-x86_64.egg/k2/fsa_algo.py", line 605, in remove_epsilon
    ragged_arc, arc_map = _k2.remove_epsilon(fsa.arcs, fsa.properties)

k2 version: k2-1.15.1.dev20220617+cpu.torch1.8.1-py3.8-linux-x86_64.egg
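
For reference, the failure happens inside k2.remove_epsilon, near the end of the HLG pipeline the log above walks through. A rough sketch of that pipeline (following icefall's local/compile_hlg.py; L, G, ctc_topo and first_token_disambig_id are assumed to be prepared as in the recipe, and details may differ between versions):

import k2

# Assumed inputs, as in the icefall recipe: L is the lexicon FST,
# G the word-level grammar FST, ctc_topo the CTC topology, and
# first_token_disambig_id the id of '#0' on the token side.
LG = k2.compose(L, G)       # log: "Intersecting L and G"
LG = k2.connect(LG)         # log: "Connecting LG"
LG = k2.determinize(LG)     # log: "Determinizing LG"
LG = k2.connect(LG)
# log: "Removing disambiguation symbols on LG"
LG.labels[LG.labels >= first_token_disambig_id] = 0
LG.__dict__["_properties"] = None
LG = k2.remove_epsilon(LG)  # <- the failed check is raised in here
LG = k2.connect(LG)
HLG = k2.compose(ctc_topo, LG, inner_labels="tokens")
HLG = k2.arc_sort(HLG)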

wwx007121 commented 2 years ago

Also, memory usage during this process is very high, about 80 GB, although the 3-gram.arpa file itself is only 2.3 GB.
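
A back-of-envelope calculation (my own, reusing the arc count from the failed check above) makes the gap plausible: the arpa is compact text, while k2 materializes every arc of the expanded graph in memory.

# Rough estimate only; the per-arc figure is an assumption.  A k2 arc
# stores src_state, dest_state, label and a float score (int32/float32),
# plus aux_labels for an FST, i.e. roughly 20 bytes per arc.
arcs = 1_219_275_412      # the number from the failed check above
bytes_per_arc = 20        # assumed rough figure
print(arcs * bytes_per_arc / 2**30)  # ~22.7 GiB for a single copy
# remove_epsilon/determinize also keep intermediate copies and arc
# maps alive, so a peak of tens of GB is not surprising.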

wwx007121 commented 2 years ago

One possibility is that the number of states/arcs during the computation overflows the 32-bit integer type used for indexing.
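
If that is the cause, it can be checked just before the failing call. A minimal sanity check (my sketch; LG here is assumed to be the graph right before k2.remove_epsilon):

import torch

INT32_MAX = torch.iinfo(torch.int32).max  # 2_147_483_647

# k2 indexes states and arcs with int32.  If an operation's output
# needs more entries than INT32_MAX, internal row_splits wrap around
# and consistency checks like the one in the log fire.
print("states:", LG.shape[0])   # 306009979 in the log above
print("arcs:", LG.num_arcs)
print("headroom:", INT32_MAX - LG.num_arcs)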

wwx007121 commented 2 years ago

In issue https://github.com/k2-fsa/icefall/issues/132, an idea was suggested: build HLG with a small-scale language model, then rescore with a large-scale language model. I ran an experiment to compare several decoding setups. The results are as follows:

dataset: librispeech, encoder: conformer (12 layers), decoder: transformer (3 layers), model size: 27M

HLG: 3gram.3e-7, G: 3gram.3e-7
best wer: ngram_lm_scale_0.01_attention_scale_8.0.result.cer: wer 4.81, ngram_lm_scale_1.5_attention_scale_0.6.result.cer: wer 4.63

HLG: 3gram.1e-7, G: 3gram.1e-7
best wer: ngram_lm_scale_0.01_attention_scale_10.0.result.cer: wer 4.52, ngram_lm_scale_5.0_attention_scale_5.0.result.cer: wer 4.29

HLG: 3gram.3e-7, G: 3gram.1e-7
best wer: ngram_lm_scale_0.01_attention_scale_10.0.result.cer: wer 5.04, ngram_lm_scale_1.5_attention_scale_0.6.result.cer: wer 4.77

This shows that the hybrid setup (HLG built from the more heavily pruned LM, rescored with the less pruned one) is even worse than decoding with only the small language model.
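
For context, result names like ngram_lm_scale_0.01_attention_scale_8.0 come from sweeping the LM scale during rescoring. A hedged sketch of how such a sweep can be driven (the scale values below are illustrative, and the exact icefall signature may differ across versions):

# Illustrative sketch: rescore_with_whole_lattice accepts an optional
# lm_scale_list; given one, it returns one rescored result per scale,
# which is where the "ngram_lm_scale_..." result names come from.
lm_scale_list = [0.01, 0.1, 0.6, 1.5, 5.0]  # illustrative values
rescored = rescore_with_whole_lattice(
    lattice=lattice,
    G_with_epsilon_loops=G,
    lm_scale_list=lm_scale_list,
)  # dict keyed by LM scale -> rescored lattice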

My code for loading the large language model:

import os

import k2
from icefall.lexicon import Lexicon

lexicon = Lexicon("data/lang_phone/")
first_word_disambig_id = lexicon.word_table["#0"]
with open(os.path.join("data/lang_phone/", "G_3_gram.large.fst.txt")) as f:
    G = k2.Fsa.from_openfst(f.read(), acceptor=False)
    del G.aux_labels  # rescoring uses G as an acceptor; drop output labels
    # Map disambiguation symbols (#0 and above) to epsilon
    G.labels[G.labels >= first_word_disambig_id] = 0
    G.__dict__["_properties"] = None  # invalidate cached properties
    G = k2.Fsa.from_fsas([G])  # wrap in an FsaVec
    G = k2.arc_sort(G)
    G = k2.add_epsilon_self_loops(G)
    G = k2.arc_sort(G)
    G.lm_scores = G.scores.clone()

Code to generate the lattice and rescore it:

lattice = get_lattice(
    nnet_output=netout_local,
    decoding_graph=HLG,
    supervision_segments=supervision_segments,
    search_beam=search_beam,
    output_beam=output_beam,
    min_active_states=min_active_states,
    max_active_states=max_active_states,
    subsampling_factor=subsampling_factor,
)

lattice = rescore_with_whole_lattice(
    lattice=lattice,
    G_with_epsilon_loops=G,
)
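
For completeness, a hedged sketch of turning the rescored lattice into word sequences with icefall's decode utilities (one_best_decoding and get_texts; exact behavior may vary across icefall versions):

from icefall.decode import one_best_decoding
from icefall.utils import get_texts

# Pick the single best path through the rescored lattice, then map
# word ids back to words via the lexicon's symbol table.
best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
hyp_ids = get_texts(best_path)  # one list of word ids per utterance
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyp_ids]
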
duj12 commented 1 year ago

I meet the same problem. I hope the compiled HLG can be smaller.