k2-fsa / icefall

https://k2-fsa.github.io/icefall/
Apache License 2.0

Streaming `fast_beam_search` #1169

Open desh2608 opened 1 year ago

desh2608 commented 1 year ago

I am working on streaming, long-form decoding on TED-Lium. I have a causal Zipformer encoder, and I have adapted the streaming_decode.py script to decode full recordings in a streaming fashion and also return token-level timestamps.

So far, greedy_search and modified_beam_search are working fine. I am now trying to use fast_beam_search_nbest_LG, since I have previously obtained good gains with LG decoding on TED-Lium. However, the streaming version of this decoding method (and of all fast_beam_search variants) is extremely slow. My guess is that this is caused by the design of the streaming algorithm: after processing each chunk, it generates a lattice and calls one_best_decoding() on the lattice (line 291), and the intermediate hypothesis is then updated with the result (line 295). This means that lattice generation and shortest path are run every time we process a chunk. I verified that as we process more chunks, the time taken by lattice generation and shortest path increases linearly.
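If the per-chunk cost grows linearly with the number of frames decoded so far, the total cost grows quadratically with the length of the recording. The toy cost model below is a hypothetical illustration, not a measurement of k2: with n chunks of c frames each, rebuilding the lattice over all frames after every chunk touches c·n(n+1)/2 frames in total, versus c·n if each chunk only touched its own frames.

```python
def frames_processed(num_chunks: int, chunk_frames: int, recompute_history: bool) -> int:
    """Toy count of frames touched by lattice generation + shortest path.

    recompute_history=True models rebuilding the lattice over every frame seen
    so far after each chunk; False models touching only the new chunk's frames.
    """
    total = 0
    for k in range(1, num_chunks + 1):
        frames_so_far = k * chunk_frames
        total += frames_so_far if recompute_history else chunk_frames
    return total


# 16 frames per chunk is an arbitrary choice for illustration.
for n in (10, 100, 1000):
    full = frames_processed(n, 16, recompute_history=True)
    incremental = frames_processed(n, 16, recompute_history=False)
    print(n, full, incremental, full / incremental)
```

The ratio between the two regimes is (n + 1)/2, so the overhead keeps growing with the number of chunks, which matters most for full-length recordings.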

In addition to being slow due to the recomputation, I think it is also a wrong way of streaming decoding, since we are modifying previously generated tokens based on future audio. In contrast, the streaming version of greedy_search is correct, in that it just keeps appending tokens to existing hypothesis.

pkufool commented 1 year ago

However, the streaming version of this decoding method (and of all fast_beam_search variants) is extremely slow.

How slow? Do you have any numbers, such as decoding time or RTF?

My guess is that this is caused by the design of the streaming algorithm --- after processing each chunk, it generates a lattice and calls one_best_decoding() on the lattice (line 291).

This depends on how often you want partial results. It can easily be fixed by adding an option like return_partial: if True, run one_best_decoding; if False, do nothing.
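A minimal sketch of the control flow this option would give, assuming a hypothetical return_partial flag (it is a suggestion here, not an existing parameter). The callables advance and best_path are hypothetical stand-ins for the per-chunk search step and for lattice generation plus one_best_decoding(); the real k2/icefall calls are not reproduced.

```python
from typing import Callable, List, Optional


def decode_chunk(
    advance: Callable[[], None],
    best_path: Callable[[], List[int]],
    return_partial: bool,
) -> Optional[List[int]]:
    # Always advance the search over the new chunk's frames.
    advance()
    # Only pay for lattice generation + shortest path when a result is wanted.
    if return_partial:
        return best_path()
    return None


def decode_stream(
    num_chunks: int,
    advance: Callable[[], None],
    best_path: Callable[[], List[int]],
    partial_every: int = 0,
) -> Optional[List[int]]:
    """Request a partial result every `partial_every` chunks (0 = final only)."""
    hyp = None
    for i in range(num_chunks):
        is_last = i == num_chunks - 1
        periodic = partial_every > 0 and (i + 1) % partial_every == 0
        out = decode_chunk(advance, best_path, return_partial=is_last or periodic)
        if out is not None:
            hyp = out
    return hyp
```

Note that this only reduces how often the best path is computed; when it is computed, it still spans every frame decoded so far.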

In addition to being slow due to the recomputation, I think it is also a wrong way of streaming decoding, since we are modifying previously generated tokens based on future audio. In contrast, the streaming version of greedy_search is correct, in that it just keeps appending tokens to existing hypothesis.

Yes, that is a problem. In the current implementation we stack and unstack all the previous states and arcs, but only the states and arcs from step t-1 are actually needed for the current chunk. The same thing happens in OnlineIntersectDensePruned (i.e., streaming CTC decoding); I will try to fix it. One caveat is that the concatenated best paths of the chunks might not be the best path of the whole lattice.
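The caveat in the last sentence can be seen on a tiny hypothetical lattice with one layer of arcs per chunk. Committing to the best partial path at each chunk boundary (and continuing only from its end state) can miss the globally best path, which a Viterbi pass over the whole lattice finds. The lattice and scores below are invented for illustration.

```python
from typing import Dict, List, Tuple

# arcs[k] maps (src_state, dst_state) -> log-prob score for chunk k.
Arcs = List[Dict[Tuple[str, str], float]]


def global_best(arcs: Arcs, start: str) -> Tuple[float, List[str]]:
    # Viterbi over the whole lattice: best score and path for every frontier state.
    best = {start: (0.0, [start])}
    for layer in arcs:
        nxt: Dict[str, Tuple[float, List[str]]] = {}
        for (src, dst), w in layer.items():
            if src in best:
                score = best[src][0] + w
                if dst not in nxt or score > nxt[dst][0]:
                    nxt[dst] = (score, best[src][1] + [dst])
        best = nxt
    return max(best.values(), key=lambda x: x[0])


def chunkwise_best(arcs: Arcs, start: str) -> Tuple[float, List[str]]:
    # Commit to the single best partial path after every chunk.
    score, path = 0.0, [start]
    for layer in arcs:
        cands = [(score + w, dst) for (src, dst), w in layer.items() if src == path[-1]]
        score, dst = max(cands)
        path = path + [dst]
    return score, path


lattice: Arcs = [
    {("S", "A"): 2.0, ("S", "B"): 1.0},  # chunk 1: A looks better locally
    {("A", "F"): 0.0, ("B", "F"): 3.0},  # chunk 2: but B leads to the best total
]
print(global_best(lattice, "S"))
print(chunkwise_best(lattice, "S"))
```

Keeping all frontier states from step t-1 (rather than a single committed state) preserves the global best path; committing per chunk does not.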

desh2608 commented 1 year ago

However, the streaming version of this decoding method (and of all fast_beam_search variants) is extremely slow.

How slow? Do you have any numbers, such as decoding time or RTF?

It takes ~492 s to decode one recording (~8 minutes long) from the TED-Lium dev set.
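In terms of the RTF asked about above, real-time factor is decoding time divided by audio duration. Assuming the recording is exactly 8 minutes (the post says ~8 minutes), the figure works out to roughly 1, i.e., about as slow as real time:

```python
def real_time_factor(decode_seconds: float, audio_seconds: float) -> float:
    # RTF < 1 means decoding runs faster than real time.
    return decode_seconds / audio_seconds


rtf = real_time_factor(492.0, 8 * 60.0)
print(f"RTF = {rtf:.3f}")
```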

My guess is that this is caused by the design of the streaming algorithm --- after processing each chunk, it generates a lattice and calls one_best_decoding() on the lattice (line 291).

This depends on how often you want partial results. It can easily be fixed by adding an option like return_partial: if True, run one_best_decoding; if False, do nothing.

Possibly, but it is a little hard, since the k2.RnntDecodingStreams object is only created for the current chunk. I tried adding an option to do lattice generation and one-best search only if at least one of the streams has finished processing, but got the following error in format_output:

[F] /exp/draj/jsalt2023/k2/k2/csrc/rnnt_decode.cu:690:void k2::rnnt_decoding::RnntDecodingStreams::GatherPrevFrames(const std::vector<int>&) Check failed: num_frames[i] <= static_cast<int32_t>(srcs_[i]->prev_frames.size()) (11632 vs. 16)

[ Stack-Trace: ]
/exp/draj/jsalt2023/k2/build_debug/lib/libk2_log.so(k2::internal::GetStackTrace()+0x46) [0x2aab3838ab88]
/exp/draj/jsalt2023/k2/build_debug/lib/libk2context.so(k2::internal::Logger::~Logger()+0x35) [0x2aab3242f2c5]
/exp/draj/jsalt2023/k2/build_debug/lib/libk2context.so(k2::rnnt_decoding::RnntDecodingStreams::GatherPrevFrames(std::vector<int, std::allocator<int> > const&)+0x47d) [0x2aab32657b51]
/exp/draj/jsalt2023/k2/build_debug/lib/libk2context.so(k2::rnnt_decoding::RnntDecodingStreams::FormatOutput(std::vector<int, std::allocator<int> > const&, bool, k2::Ragged<k2::Arc>*, k2::Array1<int>*, k2::Array1<int>*, k2::RaggedShape const&)+0x393) [0x2aab32658d3b]
/exp/draj/jsalt2023/k2/build_debug/lib/libk2context.so(k2::rnnt_decoding::RnntDecodingStreams::FormatOutput(std::vector<int, std::allocator<int> > const&, bool, k2::Ragged<k2::Arc>*, k2::Array1<int>*)+0x6d) [0x2aab32658061]
/exp/draj/jsalt2023/k2/build_debug/lib/_k2.cpython-38-x86_64-linux-gnu.so(+0x219e07) [0x2aab2d9bce07]
/exp/draj/jsalt2023/k2/build_debug/lib/_k2.cpython-38-x86_64-linux-gnu.so(+0x21eb33) [0x2aab2d9c1b33]
/exp/draj/jsalt2023/k2/build_debug/lib/_k2.cpython-38-x86_64-linux-gnu.so(+0x21e32c) [0x2aab2d9c132c]
/exp/draj/jsalt2023/k2/build_debug/lib/_k2.cpython-38-x86_64-linux-gnu.so(+0x21d9ef) [0x2aab2d9c09ef]
/exp/draj/jsalt2023/k2/build_debug/lib/_k2.cpython-38-x86_64-linux-gnu.so(+0x21da9f) [0x2aab2d9c0a9f]
/exp/draj/jsalt2023/k2/build_debug/lib/_k2.cpython-38-x86_64-linux-gnu.so(+0x9d5e9) [0x2aab2d8405e9]
python(PyCFunction_Call+0x52) [0x4f5652]
python(_PyObject_MakeTpCall+0x3bb) [0x4e0c8b]
python() [0x4f53fd]
python(_PyEval_EvalFrameDefault+0x49a9) [0x4dc999]
python(_PyEval_EvalCodeWithName+0x2f1) [0x4d6fb1]
python() [0x4f51bb]
python(_PyEval_EvalFrameDefault+0x1150) [0x4d9140]
python(_PyEval_EvalCodeWithName+0x2f1) [0x4d6fb1]
python(_PyFunction_Vectorcall+0x19c) [0x4e807c]
python(_PyEval_EvalFrameDefault+0x1150) [0x4d9140]
python(_PyEval_EvalCodeWithName+0x2f1) [0x4d6fb1]
python(_PyFunction_Vectorcall+0x19c) [0x4e807c]
python(_PyEval_EvalFrameDefault+0x1150) [0x4d9140]
python(_PyEval_EvalCodeWithName+0x2f1) [0x4d6fb1]
python(_PyFunction_Vectorcall+0x19c) [0x4e807c]
python(_PyEval_EvalFrameDefault+0x1150) [0x4d9140]
python(_PyFunction_Vectorcall+0x106) [0x4e7fe6]
python(PyObject_Call+0x24a) [0x4f768a]
python(_PyEval_EvalFrameDefault+0x1f7b) [0x4d9f6b]
python(_PyEval_EvalCodeWithName+0x2f1) [0x4d6fb1]
python(_PyFunction_Vectorcall+0x19c) [0x4e807c]
python(_PyEval_EvalFrameDefault+0x399) [0x4d8389]
python(_PyEval_EvalCodeWithName+0x2f1) [0x4d6fb1]
python(PyEval_EvalCodeEx+0x39) [0x585d79]
python(PyEval_EvalCode+0x1b) [0x585d3b]
python() [0x5a5a91]
python() [0x5a4a9f]
python() [0x45c417]
python(PyRun_SimpleFileExFlags+0x340) [0x45bfb8]
python() [0x44fd9e]
python(Py_BytesMain+0x39) [0x579dd9]
/lib64/libc.so.6(__libc_start_main+0xf5) [0x2aaaab616445]
python() [0x579c8d]

Traceback (most recent call last):
  File "zipformer/streaming_decode.py", line 912, in <module>
    main()
  File "/home/hltcoe/draj/.conda/envs/jsalt/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "zipformer/streaming_decode.py", line 896, in main
    results_dict = decode_dataset(
  File "zipformer/streaming_decode.py", line 656, in decode_dataset
    finished_streams = decode_one_chunk(
  File "zipformer/streaming_decode.py", line 535, in decode_one_chunk
    fast_beam_search_one_best(
  File "/exp/draj/jsalt2023/icefall/egs/tedlium3/ASR/zipformer/streaming_beam_search.py", line 355, in fast_beam_search_one_best
    lattice = decoding_streams.format_output(
  File "/exp/draj/jsalt2023/k2/k2/python/k2/rnnt_decode.py", line 201, in format_output
    ragged_arcs, out_map = self.streams.format_output(
RuntimeError:
Some bad things happened. Please read the above error messages and stack
trace. If you are using Python, the following command may be helpful:

gdb --args python /path/to/your/code.py

I suppose one way to solve this would be to create a global k2.RnntDecodingStreams object and pass it in when decoding each chunk, but that is a little hacky, and I was wondering if there is a cleaner way.

In addition to being slow due to the recomputation, I think it is also a wrong way of streaming decoding, since we are modifying previously generated tokens based on future audio. In contrast, the streaming version of greedy_search is correct, in that it just keeps appending tokens to existing hypothesis.

Yes, that is a problem. In the current implementation we stack and unstack all the previous states and arcs, but only the states and arcs from step t-1 are actually needed for the current chunk. The same thing happens in OnlineIntersectDensePruned (i.e., streaming CTC decoding); I will try to fix it. One caveat is that the concatenated best paths of the chunks might not be the best path of the whole lattice.

pkufool commented 1 year ago

I tried adding an option to do lattice generation and one-best search only if at least one of the streams has finished processing.

Actually, it is almost impossible to implement this in a real scenario, because the set of streams used to construct a k2.RnntDecodingStreams is dynamic, so any chunk might contain a stream that terminates in that chunk.

pkufool commented 1 year ago

It takes ~492 s to decode one recording (~8 minutes long) from the TED-Lium dev set.

Oh, that is bad. We have never tried such long audio.

desh2608 commented 1 year ago

I tried adding an option to do lattice generation and one-best search only if at least one of the streams has finished processing.

Actually, it is almost impossible to implement this in a real scenario, because the set of streams used to construct a k2.RnntDecodingStreams is dynamic, so any chunk might contain a stream that terminates in that chunk.

Do you have any suggestions on how to speed it up?

danpovey commented 1 year ago

If the stream gets too long, shortest-path will eventually stop working well. I think the issue is that we shouldn't be attempting to create a lattice from a very long stream. The question, though, is at what stage of the workflow we should be segmenting. Perhaps Desh could try the method we used in the text_align example (long_file_recog.sh), where we decode with overlapping chunks? Or we could come up with some other algorithm, e.g. one based on the volume in a 0.2 s sliding window, that creates chunks with some maximum duration, e.g. 30 seconds? [It would presumably have a similar interface to the segmentation part of that long_file_recog.sh script.]
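A minimal sketch of the volume-based idea, under stated assumptions: "volume" is taken to be mean-square energy over a 0.2 s window slid with a 0.1 s hop, and whenever the remaining audio exceeds the maximum duration, the cut is placed at the quietest window within the last few seconds (search_seconds, a hypothetical parameter) before the limit. The function name and defaults are illustrative; this is not the segmentation used by long_file_recog.sh.

```python
import numpy as np


def segment_by_energy(
    samples,
    sample_rate: int,
    max_seconds: float = 30.0,
    window_seconds: float = 0.2,
    search_seconds: float = 5.0,
):
    """Return (start, end) sample indices of segments no longer than max_seconds.

    Assumes search_seconds < max_seconds.
    """
    x = np.asarray(samples, dtype=np.float64)
    win = max(1, int(window_seconds * sample_rate))
    hop = max(1, win // 2)  # slide the window by half its length
    max_len = int(max_seconds * sample_rate)
    search = int(search_seconds * sample_rate)
    n = len(x)
    segments = []
    start = 0
    while n - start > max_len:
        # Candidate windows lie in the last `search` samples before the limit.
        lo = start + max_len - search
        hi = start + max_len - win
        best_cut, best_energy = hi, np.inf
        for pos in range(lo, hi + 1, hop):
            energy = float(np.mean(x[pos:pos + win] ** 2))
            if energy < best_energy:
                # Cut in the middle of the quietest window found so far.
                best_cut, best_energy = pos + win // 2, energy
        segments.append((start, best_cut))
        start = best_cut
    segments.append((start, n))
    return segments
```

Each segment could then be decoded independently (streaming or not), so no single lattice ever covers more than max_seconds of audio.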

desh2608 commented 1 year ago

I have a long-form inference pipeline working well with overlapping chunks (using a non-causal zipformer). But I was also trying to do it in a streaming way so there is no need to create the chunks and merge tokens after. I imagined there would be a way to do this by spitting out tokens intermittently.
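The post does not say how tokens from overlapping chunks are merged. One common approach, assumed here for illustration, uses the token-level timestamps: each chunk "owns" the region between the midpoints of its overlaps with its neighbors, and only tokens whose absolute time falls in that region are kept. The function and data layout are hypothetical.

```python
import math
from typing import List, Tuple

Token = Tuple[str, float]  # (token, absolute time in seconds)
Chunk = Tuple[float, float, List[Token]]  # (start, end, decoded tokens)


def merge_overlapping(chunks: List[Chunk]) -> List[Token]:
    """Merge tokens from chunks sorted by start time, where neighbors overlap."""
    merged: List[Token] = []
    for i, (start, end, tokens) in enumerate(chunks):
        # Own the span from the midpoint of the previous overlap
        # to the midpoint of the next one.
        lo = -math.inf if i == 0 else (start + chunks[i - 1][1]) / 2
        hi = math.inf if i == len(chunks) - 1 else (chunks[i + 1][0] + end) / 2
        merged.extend(t for t in tokens if lo <= t[1] < hi)
    return merged


chunks = [
    (0.0, 10.0, [("a", 1.0), ("b", 8.5), ("c", 9.5)]),
    (8.0, 18.0, [("b", 8.6), ("c", 9.4), ("d", 12.0)]),
]
print(merge_overlapping(chunks))
```

Cutting at the overlap midpoint keeps tokens from the chunk that saw the most context on both sides of them.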

pkufool commented 1 year ago

I think we can use the endpointer in sherpa (implemented by Fangjun): when an endpoint is detected, we drop the history states and start again from state 0.
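A minimal sketch of endpoint-driven resetting, loosely modeled on the idea of rule-based endpointing; it does not reproduce sherpa's actual rules or API. It assumes one output per encoder frame (blank_id meaning "nothing emitted"), a 40 ms frame shift, and two illustrative rules: enough trailing blank after at least one decoded token, or a maximum segment duration. All names and thresholds are hypothetical. In a fast_beam_search setting, _reset is where the stream's accumulated history would be discarded so decoding restarts from the graph's start state.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class EndpointConfig:
    frame_shift: float = 0.04            # seconds per encoder frame (assumption)
    trailing_blank_seconds: float = 1.2  # silence after decoded tokens
    max_segment_seconds: float = 20.0    # force a cut on very long segments


@dataclass
class SegmentedDecoder:
    config: EndpointConfig = field(default_factory=EndpointConfig)
    segments: List[List[int]] = field(default_factory=list)
    current: List[int] = field(default_factory=list)
    num_frames: int = 0
    trailing_blanks: int = 0

    def accept_frame(self, token: int, blank_id: int = 0) -> None:
        self.num_frames += 1
        if token == blank_id:
            self.trailing_blanks += 1
        else:
            self.current.append(token)
            self.trailing_blanks = 0
        if self._is_endpoint():
            self._reset()

    def _is_endpoint(self) -> bool:
        c = self.config
        # Compare in frames to avoid floating-point edge cases.
        blank_limit = round(c.trailing_blank_seconds / c.frame_shift)
        frame_limit = round(c.max_segment_seconds / c.frame_shift)
        silence_rule = bool(self.current) and self.trailing_blanks >= blank_limit
        return silence_rule or self.num_frames >= frame_limit

    def _reset(self) -> None:
        # Emit the finished segment and drop all history.
        if self.current:
            self.segments.append(self.current)
        self.current = []
        self.num_frames = 0
        self.trailing_blanks = 0
```

Because every segment starts fresh, no lattice or shortest-path computation ever spans more than one segment.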