Hi k2 developers!
I am trying to implement a rudimentary version of GNAT (the Globally Normalized Autoregressive Transducer) in k2, largely following the CTC implementation in Icefall:
1. A sausage with the actual scores is created using `DenseFsaVec`, which holds the `T * Q * V` arc weights (`T`: input length; `Q`: number of context states; `V`: vocab size). The arc label is `q * V + y` for context state `q` and label `y` (see the first sketch below).
2. Then I `intersect_dense` the result of Step 1 with an `Fsa` that represents the n-gram context dependency with blank loops added (`Q` states; `Q * V` arcs), whose labels are also `q * V + y`. The shortest distance (`get_tot_scores`) gives us the denominator of the training loss (see the second sketch below).
3. Similarly, I `intersect_dense` the result of Step 1 with an `Fsa` that represents the intersection of the n-gram context dependency and the reference output, with blank loops added. This gives us the numerator.
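Concretely, Step 1 amounts to something like this (a minimal sketch with toy shapes and made-up variable names, not the exact notebook code):

```python
import torch
import k2

N, T, V = 2, 100, 30   # batch size, frames, vocab size (toy numbers)
Q = V                  # toy bigram context: one state per preceding label

# Per-frame joint scores from the network, one score per (context state,
# label) pair.  GNAT is globally normalized, so no per-frame softmax.
scores = torch.randn(N, T, Q, V)

# Flatten (q, y) into the combined symbol q * V + y so that each frame
# has Q * V columns, lining up with the arc labels of the context FSA.
log_probs = scores.reshape(N, T, Q * V)

# One row per utterance: (fsa_idx, start_frame, num_frames); k2 wants
# this as int32 on the CPU.
supervision_segments = torch.tensor([[0, 0, T], [1, 0, T]], dtype=torch.int32)

dense_fsa_vec = k2.DenseFsaVec(log_probs, supervision_segments)
```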
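Steps 2 and 3 are then roughly the following. `build_context_fsa` is only a stand-in for my real context graph (a toy bigram context with `Q = V`), and the numerator graph construction is elided:

```python
def build_context_fsa(V: int, blank: int = 0) -> k2.Fsa:
    """Toy bigram context dependency: Q == V states, where state q means
    "the last label was q".  Blank loops keep the context; a non-blank
    label y moves the context to state y.  Every arc carries the combined
    symbol q * V + y to match the DenseFsaVec columns."""
    Q = V
    final_state = Q
    lines = []
    for q in range(Q):
        for y in range(V):
            dst = q if y == blank else y
            lines.append(f"{q} {dst} {q * V + y} 0")
        lines.append(f"{q} {final_state} -1 0")  # every state may be final
    lines.append(f"{final_state}")
    return k2.arc_sort(k2.Fsa.from_str("\n".join(lines)))

ctx = build_context_fsa(V)
# One copy per utterance; in real training these move to the same CUDA
# device as log_probs.
den_graphs = k2.create_fsa_vec([ctx] * N)

# Step 2: the denominator is the total score over all paths.
den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=10.0)
den = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True)

# Step 3: the numerator uses the context FSA further intersected with
# each reference (construction not shown here):
# num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=10.0)
# num = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True)
# loss = (den - num).sum()
```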
Unsurprisingly, the input to the shortest-distance computation in Step 2 is `O(T * Q * V)`, so it is quite easy to run out of memory even on a GPU with 16 GB of RAM. For instance, `T = 1000`, `Q = 500`, and `V = 128` already gives 6.4e7 scores per utterance, roughly 256 MB in fp32 before counting the lattice arcs and autograd buffers. Do you have any suggestions on how I can reduce the memory usage?

Here's a Colab notebook detailing what I am doing. Please accept my apologies for not being able to share it publicly due to company policies; instead, please use the request-access feature and kindly include your GitHub user name in the message.
Thanks a lot for the help!