k2-fsa / k2

FSA/FST algorithms, differentiable, with PyTorch compatibility.
https://k2-fsa.github.io/k2
Apache License 2.0

Techniques for reducing memory usage for large lattices such as GNAT #1169

Open · kho opened 1 year ago

kho commented 1 year ago

Hi k2 developers!

I am trying to implement a rudimentary version of GNAT in k2, largely following the CTC implementation in icefall:

  1. A sausage with actual scores is created using DenseFsaVec, which holds the T * Q * V arc weights (T: input length; Q: number of context states; V: vocab size). The arc label is q * V + y for context state q and label y.
  2. Then I intersect_dense the result of Step 1 with an Fsa that represents the n-gram context dependency with blank loops added (Q states; Q * V arcs), whose labels are also q * V + y. The shortest distance (get_tot_scores) gives us the denominator in the training loss.
  3. Similarly, I intersect_dense the result of Step 1 with an Fsa that represents the intersection of the n-gram context dependency and the reference output with blank loops added. This gives us the numerator.
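To make the label layout and the cost of Steps 1 and 2 concrete, here is a toy pure-Python sketch. It is not k2 code and not the notebook's implementation. It assumes a hypothetical context transition `next_state(q, y)`, start state 0, every context state final, and blank id 0 looping on the current state; `encode_label`, `decode_label` and `total_score` are illustrative names. The inner loops visit all Q * V (q, y) pairs at every frame, which is the T * Q * V work described above.

```python
import math
from typing import Callable, List, Tuple


def encode_label(q: int, y: int, V: int) -> int:
    """Combine context state q and output label y into one arc label q * V + y."""
    return q * V + y


def decode_label(label: int, V: int) -> Tuple[int, int]:
    """Invert encode_label: recover (q, y) from a combined label."""
    return divmod(label, V)


def logsumexp(xs: List[float]) -> float:
    """Numerically stable log(sum(exp(x))) for a non-empty list."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def total_score(
    scores: List[List[float]],  # scores[t][q * V + y]: T rows of Q * V log-scores
    Q: int,
    V: int,
    next_state: Callable[[int, int], int],  # hypothetical context transition
    blank: int = 0,
) -> float:
    """Toy log-semiring forward pass over (sausage x context graph).

    Assumptions: start in state 0, all states final, blank loops on the
    current state, any other label y moves from q to next_state(q, y).
    """
    alpha = [-math.inf] * Q
    alpha[0] = 0.0
    for row in scores:
        incoming: List[List[float]] = [[] for _ in range(Q)]
        for q in range(Q):
            if alpha[q] == -math.inf:
                continue  # state q not reachable yet
            for y in range(V):
                dest = q if y == blank else next_state(q, y)
                incoming[dest].append(alpha[q] + row[encode_label(q, y, V)])
        alpha = [logsumexp(c) if c else -math.inf for c in incoming]
    return logsumexp(alpha)
```

With all scores equal to 0 (probability 1 per arc), Q = V = 2 and a bigram-style `next_state=lambda q, y: y`, every one of the 2^3 paths over 3 frames is valid, so the total is log 8. Unlike a lattice kept for autograd, this toy holds only one vector of Q forward scores, so it illustrates the arithmetic rather than serving as a drop-in training loss.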

Unsurprisingly, the input to shortest distance in Step 2 is O(T * Q * V), so it's easy to run out of memory even on a GPU with 16 GB of memory. Do you have any suggestions on how I can reduce the memory usage?
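A back-of-the-envelope estimate shows how quickly O(T * Q * V) fills a 16 GB GPU. The sizes used below (T = 500 frames, Q = V = 1024) are hypothetical, and `bytes_per_arc = 16` is an assumption (three int32 fields plus one float32 score per arc); real lattices also carry arc maps, attributes and gradient buffers, so treat the result as a lower bound.

```python
def lattice_bytes(T: int, Q: int, V: int, batch: int = 1, bytes_per_arc: int = 16) -> int:
    """Rough lower bound on the size of a lattice with T * Q * V arcs per sequence.

    bytes_per_arc = 16 is an assumed per-arc cost, not a measured k2 figure.
    """
    return batch * T * Q * V * bytes_per_arc


# Hypothetical sizes: one sequence, 500 frames, 1024 context states, 1024 labels.
print(lattice_bytes(500, 1024, 1024) / 2**30)  # 7.8125 (GiB)
```

Under these assumptions a single sequence already needs about 7.8 GiB for the arc array alone, before multiplying by the batch size.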

Here's a Colab notebook detailing what I am doing. Apologies for not being able to share it publicly due to company policy; instead, please use the request-access feature and include your GitHub username in the message.

Thanks a lot for the help!

csukuangfj commented 1 year ago

Could you have a look at

https://github.com/k2-fsa/icefall/blob/master/icefall/mmi.py

It uses pruned intersection.
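Pruned intersection avoids expanding every context state at every frame by discarding unpromising states as it goes. The toy below only illustrates that beam idea on the same hypothetical label layout (q * V + y, blank id 0, start state 0, a hypothetical `next_state`); `pruned_total_score` is an illustrative name, not k2's algorithm. k2's `intersect_dense_pruned` has its own controls (`search_beam`, `output_beam`, `min_active_states`, `max_active_states`) and is considerably more involved; see icefall/mmi.py for how it is actually called.

```python
import math
from typing import Callable, Dict, List


def logsumexp(xs: List[float]) -> float:
    """Numerically stable log(sum(exp(x))) for a non-empty list."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def pruned_total_score(
    scores: List[List[float]],  # scores[t][q * V + y]
    V: int,
    next_state: Callable[[int, int], int],  # hypothetical context transition
    beam: float,
    blank: int = 0,
) -> float:
    """Toy beam-pruned log-semiring forward pass.

    After each frame, states whose forward score is more than `beam` below
    the best one are dropped. Dropping paths can only lower a log-sum, so
    the result is a lower bound on the exact total score.
    """
    alpha: Dict[int, float] = {0: 0.0}  # only active states are stored
    for row in scores:
        incoming: Dict[int, List[float]] = {}
        for q, a in alpha.items():
            for y in range(V):
                dest = q if y == blank else next_state(q, y)
                incoming.setdefault(dest, []).append(a + row[q * V + y])
        new_alpha = {d: logsumexp(c) for d, c in incoming.items()}
        best = max(new_alpha.values())
        alpha = {d: a for d, a in new_alpha.items() if a >= best - beam}
    return logsumexp(list(alpha.values()))
```

The trade-off is memory versus exactness: a narrow beam keeps few states per frame, but the pruned total (and hence the denominator) is an approximation of the full sum.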