k2-fsa / icefall

https://k2-fsa.github.io/icefall/
Apache License 2.0

Decoding Issue: fast beam search nbest LG #1591

Open divyeshrajpura4114 opened 1 month ago

divyeshrajpura4114 commented 1 month ago

I am trying to decode a model trained with the recipe pruned_transducer_stateless7_streaming. Decoding with fast beam search (without LG) works, but decoding with LG fails with the error below.

2024-04-12 07:23:45,958 INFO [zipformer.py:405] At encoder stack 4, which has downsampling_factor=2, we will combine the outputs of layers 1 and 3, with downsampling_factors=2 and 8.                      
2024-04-12 07:23:46,039 INFO [decode.py:873] Calculating the averaged model over epoch range from 9 (excluded) to 10                                                                                        
2024-04-12 07:23:55,740 INFO [lexicon.py:168] Loading pre-compiled k2_model/exp/lang/bpe_5000/Linv.pt                                                                                                  
2024-04-12 07:24:00,599 INFO [decode.py:942] Loading k2_model/exp/lang/bpe_5000/LG.pt                                                                                                                  
2024-04-12 07:24:01,728 INFO [decode.py:955] Number of model parameters: 79022891                                                                                                                          
2024-04-12 07:24:01,729 INFO [asr_datamodule.py:422] About to get test cuts                                                                                                                                
[F] /var/www/k2/csrc/top_sort.cu:324:k2::FsaVec k2::TopSorter::TopSort(k2::Array1<int>*) Check failed: start_state_present[0] == 1 (0 vs. 1) Our current implementation requires that the start state in each Fsa must be present in the first batch

[ Stack-Trace: ]                                                                                                                                                                                            
/opt/conda/lib/python3.10/site-packages/k2/lib64/libk2_log.so(k2::internal::GetStackTrace()+0x34) [0x7f426ef829b4]                                                                                          
/opt/conda/lib/python3.10/site-packages/k2/lib64/libk2context.so(k2::internal::Logger::~Logger()+0x2a) [0x7f426f47ed4a]                                                                                    
/opt/conda/lib/python3.10/site-packages/k2/lib64/libk2context.so(k2::TopSorter::TopSort(k2::Array1<int>*)+0x452) [0x7f426f8019d2]                                                                          
/opt/conda/lib/python3.10/site-packages/k2/lib64/libk2context.so(k2::TopSort(k2::Ragged<k2::Arc>&, k2::Ragged<k2::Arc>*, k2::Array1<int>*)+0x134) [0x7f426f7f49b4]                                          
/opt/conda/lib/python3.10/site-packages/_k2.cpython-310-x86_64-linux-gnu.so(+0x829af) [0x7f4274ebf9af]
/opt/conda/lib/python3.10/site-packages/_k2.cpython-310-x86_64-linux-gnu.so(+0x3dda7) [0x7f4274e7ada7]
python3() [0x4fc697]                                                                                                                                                                                        
python3(_PyObject_MakeTpCall+0x25b) [0x4f614b]                                                                                                                                                              
python3(_PyEval_EvalFrameDefault+0x5757) [0x4f26f7]                                                                                                                                                        
python3(_PyFunction_Vectorcall+0x6f) [0x4fcadf]                                                                                                                                                            
python3(_PyEval_EvalFrameDefault+0x4b26) [0x4f1ac6]                                                                                                                                                        
python3(_PyFunction_Vectorcall+0x6f) [0x4fcadf]                                                                                                                                                            
python3(_PyEval_EvalFrameDefault+0x13b3) [0x4ee353]
python3(_PyFunction_Vectorcall+0x6f) [0x4fcadf]
python3(_PyEval_EvalFrameDefault+0x13b3) [0x4ee353]
python3(_PyFunction_Vectorcall+0x6f) [0x4fcadf]
python3(_PyEval_EvalFrameDefault+0x13b3) [0x4ee353]
python3(_PyFunction_Vectorcall+0x6f) [0x4fcadf]
python3(_PyEval_EvalFrameDefault+0x2b79) [0x4efb19]
python3(_PyFunction_Vectorcall+0x6f) [0x4fcadf]
python3(_PyEval_EvalFrameDefault+0x31f) [0x4ed2bf]
python3() [0x591d92]
python3(PyEval_EvalCode+0x87) [0x591cd7]
python3() [0x5c2967]
python3() [0x5bdad0]
python3() [0x45956b]
python3(_PyRun_SimpleFileObject+0x19f) [0x5b805f]
python3(_PyRun_AnyFileObject+0x43) [0x5b7dc3]
python3(Py_RunMain+0x38d) [0x5b4b7d]
python3(Py_BytesMain+0x39) [0x584e49]
/usr/lib/x86_64-linux-gnu/libc.so.6(+0x29d90) [0x7f4329dd6d90]
/usr/lib/x86_64-linux-gnu/libc.so.6(__libc_start_main+0x80) [0x7f4329dd6e40]
python3() [0x584cfe]

Traceback (most recent call last):
  File "/workspace/icefall/egs/personal/pruned_transducer_stateless7_streaming/decode.py", line 993, in <module>
    main()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/icefall/egs/personal/pruned_transducer_stateless7_streaming/decode.py", line 970, in main
    results_dict = decode_dataset(
  File "/workspace/icefall/egs/personal/pruned_transducer_stateless7_streaming/decode.py", line 651, in decode_dataset
    hyps_dict = decode_one_batch(
  File "/workspace/icefall/egs/personal/pruned_transducer_stateless7_streaming/decode.py", line 442, in decode_one_batch
    hyp_tokens = fast_beam_search_nbest_LG(
  File "/workspace/icefall/egs/personal/pruned_transducer_stateless7_streaming/beam_search.py", line 223, in fast_beam_search_nbest_LG
    path_lattice = k2.top_sort(k2.connect(path_lattice))
  File "/opt/conda/lib/python3.10/site-packages/k2/fsa_algo.py", line 244, in top_sort
    ragged_arc, arc_map = _k2.top_sort(fsa.arcs, need_arc_map=need_arc_map)

Library versions:

torch: 2.2.1
k2: 1.24.4
icefall: pulled on 18th Mar 2024

Thanks,
Divyesh Rajpura

divyeshrajpura4114 commented 1 month ago

When I change the device from GPU to CPU, decoding works fine.

I am using CUDA 12.3 and, as described in k2-1204, top_sort.cu does not work for CUDA > 12.

Is this correct?
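
If the CPU workaround holds, it may be enough to move only the failing call to the CPU rather than the whole decoding run. Below is a minimal sketch of that idea. The helper `run_on_cpu` is hypothetical (it is not part of icefall or k2) and assumes only that the object exposes `.device` and `.to(device)`, as `k2.Fsa` does. Whether this avoids the check failure on CUDA 12.3 is untested here.

```python
from typing import Any, Callable


def run_on_cpu(op: Callable[[Any], Any], fsa: Any) -> Any:
    """Apply `op` to a CPU copy of `fsa`, then move the result back.

    Hypothetical helper: `fsa` is assumed to expose `.device` and
    `.to(device)`, as k2.Fsa does.
    """
    original_device = fsa.device
    # Run the operation on a CPU copy so the CUDA code path is not used.
    result = op(fsa.to("cpu"))
    # Return the result on the device the rest of the decoding code expects.
    return result.to(original_device)


# Possible use at the failing line in beam_search.py (requires k2):
#   import k2
#   path_lattice = run_on_cpu(
#       lambda f: k2.top_sort(k2.connect(f)), path_lattice
#   )
```

This keeps the encoder and the beam search itself on the GPU and pays the device-transfer cost only for the connect and top-sort step.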