k2-fsa / fast_rnnt

A torch implementation of a recursion which turns out to be useful for RNN-T.

Trying to Understand pruned_loss #8

Closed. Anwarvic closed this issue 2 years ago.

Anwarvic commented 2 years ago

Using my transducer model, I have tried both the pruned and the unpruned loss. The unpruned version worked pretty well, even outperforming torchaudio.rnnt_loss. The problem is with the pruned version: the model converges very slowly and the WER and CER do not improve, even though I tried different prune_range values. Is this expected?

Also, I was wondering what is the best way to understand the pruned loss other than reading the code?

csukuangfj commented 2 years ago

You can find an example that uses both the unpruned and pruned losses in https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless2/model.py#L72

It uses so-called "model warmup": the basic idea is to first rely on the unpruned version to let the model converge, and then bring in the pruned version gradually.
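Roughly, that example first computes a smoothed unpruned ("simple") loss, uses its gradients to pick pruning bounds, and only then evaluates the pruned loss with the full joiner. Below is a condensed, illustrative sketch of that flow using the fast_rnnt API; names such as `am_simple`, `lm_simple`, `y_padded`, `blank_id`, `prune_range`, and `joiner` are placeholders, and the projection details of the real recipe are simplified:

```python
import torch
import fast_rnnt

# Placeholders (shapes for a batch of size B, T frames, S labels, vocab size V):
#   am_simple: (B, T, V)      encoder output projected to the vocab size
#   lm_simple: (B, S + 1, V)  prediction-network output projected to the vocab size
#   am:        (B, T, J)      encoder output projected to the joiner dim
#   lm:        (B, S + 1, J)  prediction-network output projected to the joiner dim
#   y_padded:  (B, S)         padded label sequences
#   boundary:  (B, 4) int64   columns [0, 0, num_labels, num_frames] per utterance

# 1. Smoothed unpruned ("simple") loss; return_grad=True also returns the
#    occupation gradients needed to choose the pruning bounds.
simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_smoothed(
    lm=lm_simple,
    am=am_simple,
    symbols=y_padded,
    termination_symbol=blank_id,
    lm_only_scale=0.25,
    am_only_scale=0.0,
    boundary=boundary,
    reduction="sum",
    return_grad=True,
)

# 2. For every frame, keep only a window of `prune_range` symbol positions.
ranges = fast_rnnt.get_rnnt_prune_ranges(
    px_grad=px_grad,
    py_grad=py_grad,
    boundary=boundary,
    s_range=prune_range,
)

# 3. Gather the pruned (B, T, prune_range, J) slices and run the full
#    joiner only on them.
am_pruned, lm_pruned = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
logits = joiner(am_pruned, lm_pruned)  # your joint network -> (B, T, prune_range, V)

# 4. Pruned RNN-T loss evaluated on the small lattice only.
pruned_loss = fast_rnnt.rnnt_loss_pruned(
    logits=logits,
    symbols=y_padded,
    ranges=ranges,
    termination_symbol=blank_id,
    boundary=boundary,
    reduction="sum",
)
```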


Also, I was wondering what is the best way to understand the pruned loss other than reading the code?

We will post the paper soon.

Anwarvic commented 2 years ago

Thanks for your quick response. Yeah, I can see that clearly in https://github.com/k2-fsa/icefall/blob/7100c33820c8c478e07d3435e25e4f1543b6eec7/egs/librispeech/ASR/pruned_transducer_stateless2/train.py#L557

But assuming I used the pruned_loss as shown here, model convergence is expected to be slow... right?

csukuangfj commented 2 years ago

model convergence is expected to be slow... right

It depends on which optimizer, which model, and which LR scheduler you are using.

The model, optimizer, and LR scheduler in icefall are specifically tuned for pruned RNN-T training.

danpovey commented 2 years ago

Are you also including the simple_loss in your loss function? You need to include that so it trains the associated parameters and learns reasonable pruning bounds. As Fangjun mentioned, it makes sense to use only the simple loss during a warmup period.
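For illustration, here is a minimal sketch of one way to phase the pruned loss in over a warmup period; the scales, the linear schedule, and `warm_step` below are illustrative placeholders, not the exact values used in the icefall recipes:

```python
def combine_rnnt_losses(simple_loss, pruned_loss, batch_idx_train, warm_step=80_000):
    """Blend the simple (unpruned) and pruned RNN-T losses.

    Early in training the simple loss carries most of the weight, so that the
    model (and hence the pruning bounds) becomes reasonable; the pruned loss
    is phased in gradually and dominates after `warm_step` batches.
    """
    warmup = min(batch_idx_train / warm_step, 1.0)
    simple_scale = 1.0 - 0.5 * warmup   # 1.0 at the start, 0.5 after warmup
    pruned_scale = 0.1 + 0.9 * warmup   # 0.1 at the start, 1.0 after warmup
    return simple_scale * simple_loss + pruned_scale * pruned_loss
```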

Anwarvic commented 2 years ago

Hey @danpovey @csukuangfj, thank you for your quick responses... You are my safety net here :smiley:

I've added the warmup functionality to my code the same way it was implemented here. Now, I'm gonna re-train the model and see if that improves the convergence. Will update this thread once it's done.

csukuangfj commented 2 years ago

Also, I was wondering what is the best way to understand the pruned loss other than reading the code?

The paper is online now: https://arxiv.org/abs/2206.13236


Anwarvic commented 2 years ago

Since the paper is out now, I'm gonna close this issue. Much appreciated!