k2-fsa / fast_rnnt

A torch implementation of a recursion which turns out to be useful for RNN-T.

AssertionError: assert py.is_contiguous() #14

Closed: Anwarvic closed this issue 2 years ago

Anwarvic commented 2 years ago

I'm working on integrating FastRNNT with Speechbrain; see this Pull Request.

At the moment, I'm trying to train a transducer model on the French portion of the multilingual TEDx dataset (mTEDx). Whenever I train my model, I get this assertion error (the issue's title). However, the mutual_information.py file says:

# The following assertions are for efficiency
assert px.is_contiguous()
assert py.is_contiguous()

Once I comment out these two lines, everything works just fine. Using a transducer model with a pre-trained wav2vec2 encoder plus one linear layer, and a one-layer GRU as a decoder, the model trains fine and I got 14.37 WER on the French test set, which is much better than our baseline.

Now, I have these two questions: is it safe to comment out these two assertions, and will doing so affect performance or memory usage?

Your guidance is much appreciated!

csukuangfj commented 2 years ago

If I remember correctly, the C++ code uses a tensor accessor to access the data, which does not require a contiguous tensor.

But a contiguous tensor is more cache-friendly, so I suggest changing it to

px = px.contiguous()
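
For illustration, here is a minimal sketch (not from the thread; the shapes are made up) of how a tensor typically ends up non-contiguous and what .contiguous() does:

import torch

x = torch.randn(2, 3, 5)
y = x.permute(0, 2, 1)    # a view with rearranged strides; no data is copied
print(y.is_contiguous())  # False -- this is what trips the assertion
y = y.contiguous()        # materializes a contiguous copy
print(y.is_contiguous())  # True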

Anwarvic commented 2 years ago

So, theoretically, commenting out these two assertions won't affect the performance... right? And changing the tensors to contiguous will just help a little with memory?

danpovey commented 2 years ago

It says right there that it's for efficiency, so yes, using non-contiguous tensors will affect the performance. Making that copy may not necessarily require more memory; it depends on whether the original (before the copy) is required for backprop. I suggest trying to add the .contiguous() call before the log_softmax, if possible: the log_softmax likely needs only its output (not its input) for backprop, so the contiguous copy feeding the log_softmax would likely not be held for backprop.
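
As a rough sketch of that suggestion (the shapes and the placement of log_softmax are assumptions here, not the actual fast_rnnt code):

import torch

# Joiner output made non-contiguous by an upstream permute (assumed layout):
logits = torch.randn(4, 10, 7, 500, requires_grad=True).permute(0, 2, 1, 3)
logits = logits.contiguous()            # the copy happens here ...
log_probs = logits.log_softmax(dim=-1)  # ... but log_softmax saves only its
                                        # output for backward, so autograd does
                                        # not need to keep the copy alive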

Anwarvic commented 2 years ago

@danpovey I'm sorry, I didn't get what you meant by "adding the .contiguous() statement before the log_softmax".

By ".contiguous() statement", you mean px = px.contiguous() and py = py.contiguous()... right?

Also, which log_softmax are we talking about here exactly? The one at the end of the joiner network?

danpovey commented 2 years ago

At some point in the RNN-T computation there is a normalization of log-probs, probably via log_softmax(); I meant doing it just before that. But this is probably not super critical, as I think it is not going to dominate memory requirements anyway: thanks to using pruned RNN-T, we are not instantiating any really huge tensors. So you can apply it to px and py, I think, if they are not naturally contiguous.

Anwarvic commented 2 years ago

I have added the following two lines just before this part in the mutual_information.py script:

if not px.is_contiguous(): px = px.contiguous()
if not py.is_contiguous(): py = py.contiguous()

@danpovey If you agree with what I did, feel free to close this issue!

csukuangfj commented 2 years ago

I think you don't need to check whether it is contiguous.

px.contiguous() is a no-op if px is already contiguous, I think.
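
That can be checked directly (a quick sanity check, not from the thread):

import torch

x = torch.randn(3, 4)
y = x.contiguous()                   # already contiguous: returns x itself
print(y.data_ptr() == x.data_ptr())  # True -- no copy was made
z = x.t().contiguous()               # non-contiguous view: this one copies
print(z.data_ptr() == x.data_ptr())  # False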

Anwarvic commented 2 years ago

Thanks for the help!

pkufool commented 2 years ago

@Anwarvic Where did you add these lines? I think there is already a px.contiguous() in rnnt_loss.py: https://github.com/danpovey/fast_rnnt/blob/c268c3d5a005968b87a724a21082410a3ec0bac3/fast_rnnt/python/fast_rnnt/rnnt_loss.py#L810-L811

pkufool commented 2 years ago

Ok, I think I forgot about get_rnnt_logprobs and get_rnnt_logprobs_smoothed.

Anwarvic commented 2 years ago

My issue was with the AssertionError, which only exists in the mutual_information.py script... I think.

pkufool commented 2 years ago

> My issue was with the AssertionError, which only exists in the mutual_information.py script... I think.

Yes, I meant that we won't call mutual_information_recursion directly; we call it from the functions in rnnt_loss.py. Anyway, fixing it in mutual_information.py is OK. Thanks!