k2-fsa / k2

FSA/FST algorithms, differentiable, with PyTorch compatibility.
https://k2-fsa.github.io/k2
Apache License 2.0

Can't use k2.intersect_dense_pruned with topk nn_output #991

Closed: yuekaizhang closed this issue 2 years ago

yuekaizhang commented 2 years ago

For example, consider a ctc_log_probs matrix with shape (1, 2, 5): one blank symbol plus 4 non-blank tokens.

It works if I construct a DenseFSAVec and run intersect_dense_pruned with its corresponding H or HLG graph (screenshot: autograd_log). However, I would like to construct a DenseFSAVec from the top-k entries of this ctc_log_probs instead, as below, with topk = 2 (screenshot: autograd_log2).
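To make the setup concrete, here is a plain-Python sketch of a genuine per-frame top-k over a toy (1, 2, 5) ctc_log_probs, with blank in column 0. The probability values are made up for illustration. In PyTorch the same selection would be `torch.topk(ctc_log_probs, k, dim=-1)`, which returns both the values and the token indices (its tie-breaking order may differ from this sketch).

```python
import math


def topk_per_frame(log_probs, k):
    """Return (values, indices) of the k largest entries in each frame.

    log_probs is a nested list shaped (batch, frames, vocab), mirroring the
    (N, T, C) layout of ctc_log_probs. Ties are broken by lower index first.
    """
    values, indices = [], []
    for utt in log_probs:
        utt_vals, utt_idx = [], []
        for frame in utt:
            # Sort vocabulary ids by descending log-prob, then by id.
            order = sorted(range(len(frame)), key=lambda c: (-frame[c], c))[:k]
            utt_idx.append(order)
            utt_vals.append([frame[c] for c in order])
        values.append(utt_vals)
        indices.append(utt_idx)
    return values, indices


# Toy (1, 2, 5) matrix (made-up numbers): column 0 is blank, 1..4 are tokens.
probs = [[[0.1, 0.2, 0.05, 0.6, 0.05],
          [0.7, 0.1, 0.1, 0.05, 0.05]]]
ctc_log_probs = [[[math.log(p) for p in frame] for frame in utt] for utt in probs]

vals, idx = topk_per_frame(ctc_log_probs, k=2)
print(idx)  # [[[3, 1], [0, 1]]]
```

Note that the indices differ from frame to frame, which is exactly the information a plain dense matrix slice cannot carry.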

It fails to do intersect_dense_pruned with this topk DenseFSAVec.

In my real case the vocabulary size is 4233. When decoding with CTC prefix beam search, keeping only the top-20 entries of ctc_log_probs per frame speeds up decoding by 100x on CPU without accuracy loss.
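For reference, a minimal plain-Python sketch of CTC prefix beam search where only the `topk` best tokens of each frame are expanded. This is the textbook algorithm with a top-k candidate filter added, not the implementation used in this issue; `beam`, `topk`, and the toy probabilities are illustrative assumptions.

```python
import math
from collections import defaultdict

NEG_INF = float("-inf")


def log_add(*xs):
    """Numerically stable log(sum(exp(x)))."""
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))


def prefix_beam_search(log_probs, beam=4, topk=2, blank=0):
    """CTC prefix beam search over one utterance of (T x C) log-probs.

    Each beam entry maps prefix -> (log p ending in blank, log p ending in
    non-blank). Only the `topk` most likely tokens per frame are expanded.
    """
    beams = {(): (0.0, NEG_INF)}
    for frame in log_probs:
        # Top-k pruning: ignore all other vocabulary entries for this frame.
        cand = sorted(range(len(frame)), key=lambda c: -frame[c])[:topk]
        nxt = defaultdict(lambda: (NEG_INF, NEG_INF))
        for prefix, (pb, pnb) in beams.items():
            for c in cand:
                p = frame[c]
                if c == blank:
                    b, nb = nxt[prefix]
                    nxt[prefix] = (log_add(b, pb + p, pnb + p), nb)
                    continue
                last = prefix[-1] if prefix else None
                new_prefix = prefix + (c,)
                b, nb = nxt[new_prefix]
                if c == last:
                    # A repeated token starts a new symbol only after a blank...
                    nxt[new_prefix] = (b, log_add(nb, pb + p))
                    # ...otherwise it collapses into the existing prefix.
                    b, nb = nxt[prefix]
                    nxt[prefix] = (b, log_add(nb, pnb + p))
                else:
                    nxt[new_prefix] = (b, log_add(nb, pb + p, pnb + p))
        # Keep the `beam` prefixes with the highest total probability.
        beams = dict(sorted(nxt.items(), key=lambda kv: -log_add(*kv[1]))[:beam])
    best = max(beams.items(), key=lambda kv: log_add(*kv[1]))
    return list(best[0]), log_add(*best[1])


# Toy example: blank = 0, two tokens (1, 2), three frames.
probs = [[0.1, 0.8, 0.1], [0.8, 0.1, 0.1], [0.1, 0.8, 0.1]]
log_probs = [[math.log(p) for p in f] for f in probs]
hyp, score = prefix_beam_search(log_probs, beam=4, topk=2)
print(hyp)  # [1, 1]
```

With a 4233-token vocabulary, the inner loop over `cand` shrinks from 4233 to `topk` iterations per prefix, which is where the reported CPU speed-up comes from.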

Back to WFST HLG decoding: I was wondering whether I could do something similar and expect a similar improvement. I have already tuned the decoding parameters (max_active_states: 7000, min_active_states: 0, search_beam: 15, output_beam: 8) with one_best_decoding. However, HLG decoding on GPU does not compare well with ctc_prefix_beam_search on CPU. Using the same ARPA file and decoding 1024 5-second utterances in one batch, ctc_prefix_beam_search takes around 1 s with four CPU threads, while HLG decoding on GPU takes around 0.5 s with one_best_decoding and around 1 s when generating n-best hypotheses.
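To put these timings on a common scale, the arithmetic below converts them to real-time factors (processing time divided by audio duration). It only restates the numbers reported above; no new measurements are implied.

```python
# Figures reported above: 1024 utterances of 5 s each, decoded in one batch.
num_utts = 1024
utt_seconds = 5.0
audio_seconds = num_utts * utt_seconds  # 5120 s of audio in total


def rtf(wall_seconds):
    """Real-time factor: processing time divided by audio duration."""
    return wall_seconds / audio_seconds


timings = {
    "ctc_prefix_beam_search (4 CPU threads)": 1.0,
    "HLG one_best_decoding (GPU)": 0.5,
    "HLG n-best (GPU)": 1.0,
}
for name, secs in timings.items():
    print(f"{name}: RTF = {rtf(secs):.2e}")
```

All three setups are far faster than real time; the comparison in this issue is about relative cost per batch.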

yuekaizhang commented 2 years ago

"It fails to do intersect_dense_pruned with this topk DenseFSAVec." It's not real topk, just ctc_log_probs[:,:,:topk]. Since current can't construct DenseFSAVec with topk ctc_log_probs and corresponding idx.

danpovey commented 2 years ago

We don't have a special kind of FSA that has the structure you want, i.e. with a fixed number of arcs from each state with provided symbols. I doubt it would give very much improvement in speed even if you did, though.
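To illustrate the structure described here (a fixed number of arcs leaving each state, with the symbols supplied per frame), the sketch below writes a linear acceptor with exactly k arcs per frame as text lines of the form `src dst label score`, followed by a final arc labeled -1 and the final state on the last line. This layout is assumed to resemble k2's textual acceptor format (check the k2 docs before relying on it). It is only an illustration: k2 has no dedicated type for this, and intersect_dense_pruned takes a DenseFsaVec, not a general FSA built like this. The function name and the toy scores are hypothetical.

```python
def topk_lattice_str(values, indices):
    """Text for a linear acceptor with exactly k arcs leaving each frame state.

    values/indices are (T, k) lists for one utterance, e.g. the output of a
    top-k over ctc_log_probs. State t -> t+1 carries the k candidate tokens of
    frame t; a final arc with label -1 leads to the final state.
    """
    lines = []
    for t, (vals, ids) in enumerate(zip(values, indices)):
        for label, score in zip(ids, vals):
            lines.append(f"{t} {t + 1} {label} {score}")
    num_frames = len(values)
    lines.append(f"{num_frames} {num_frames + 1} -1 0.0")
    lines.append(f"{num_frames + 1}")  # final state on its own line
    return "\n".join(lines)


s = topk_lattice_str([[-0.51, -1.61], [-0.36, -2.30]], [[3, 1], [0, 1]])
print(s)
```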

yuekaizhang commented 2 years ago

"We don't have a special kind of FSA that has the structure you want, i.e. with a fixed number of arcs from each state with provided symbols. I doubt it would give very much improvement in speed even if you did, though."

Thanks, Dan. Then I'll turn to other directions, e.g. blank skipping, to see if that helps.
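As a rough picture of frame-level blank skipping, the plain-Python sketch below drops frames whose blank posterior exceeds a threshold before they reach the decoder, shortening the sequence the search has to process. The function name and the 0.95 threshold are hypothetical tuning choices, not values from k2 or this thread. One known caveat: if the only blank frames separating two identical tokens are dropped, CTC collapsing can merge them into one.

```python
import math


def skip_blank_frames(log_probs, blank=0, threshold=0.95):
    """Drop frames whose blank posterior exceeds `threshold` (hypothetical knob).

    log_probs is a (T, C) list of per-frame log-probabilities. Returns the
    kept frames and their original frame indices.
    """
    kept, kept_idx = [], []
    for t, frame in enumerate(log_probs):
        if math.exp(frame[blank]) <= threshold:
            kept.append(frame)
            kept_idx.append(t)
    return kept, kept_idx


# Toy posteriors: frames 0 and 2 are confidently blank and get skipped.
probs = [[0.98, 0.01, 0.01], [0.2, 0.7, 0.1], [0.99, 0.005, 0.005], [0.5, 0.1, 0.4]]
log_probs = [[math.log(p) for p in f] for f in probs]
kept, kept_idx = skip_blank_frames(log_probs)
print(kept_idx)  # [1, 3]
```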