You can find an example of combining the unpruned + pruned loss in https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless2/model.py#L72
It uses so-called "model warmup". The basic idea is to rely on the unpruned version first so the model starts to converge, and then phase in the pruned version gradually.
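A minimal sketch of that idea (the scales, schedule shape, and `warm_step` value below are made up for illustration; see the icefall training code for the real ones):

```python
# Illustrative warmup schedule (values are hypothetical, not the
# icefall defaults).  During warmup the unpruned (simple) loss
# dominates; afterwards the pruned loss is the main training signal.
warm_step = 80_000  # number of warmup batches (hypothetical)

def combine_losses(simple_loss, pruned_loss, batch_idx_train):
    if batch_idx_train < warm_step:
        alpha = batch_idx_train / warm_step  # ramps from 0 to 1
        simple_scale = 1.0 - 0.5 * alpha     # simple loss fades 1.0 -> 0.5
        pruned_scale = alpha                 # pruned loss fades in 0 -> 1
    else:
        simple_scale, pruned_scale = 0.5, 1.0
    return simple_scale * simple_loss + pruned_scale * pruned_loss
```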
Also, I was wondering what is the best way to understand the pruned loss other than reading the code?
We will post the paper soon.
Thanks for your quick response. Yeah, I can see that clearly in https://github.com/k2-fsa/icefall/blob/7100c33820c8c478e07d3435e25e4f1543b6eec7/egs/librispeech/ASR/pruned_transducer_stateless2/train.py#L557
But assuming I used the pruned_loss as shown here, model convergence is expected to be slow... right?
model convergence is expected to be slow... right
It depends on which optimizer, model, and LR scheduler you are using.
The model, optimizer, and LR scheduler in icefall are specifically tuned for pruned RNN-T training.
Are you also including the simple_loss in your loss function? You need to include that so it trains the associated parameters and learns reasonable pruning bounds. As Fangjun mentioned, it makes sense to use only the simple loss during a warmup period.
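Roughly, the two passes fit together like this (a simplified, self-contained sketch with toy shapes and a trivial stand-in for the joiner; the model.py linked above is the authoritative version):

```python
import torch
import k2

# Toy shapes: B utterances, T frames, U labels, V vocab size.
B, T, U, V = 2, 50, 10, 500
blank_id, prune_range = 0, 5

am = torch.randn(B, T, V)       # encoder output, projected to vocab size
lm = torch.randn(B, U + 1, V)   # decoder/predictor output, likewise
symbols = torch.randint(1, V, (B, U), dtype=torch.int64)
boundary = torch.zeros(B, 4, dtype=torch.int64)
boundary[:, 2] = U              # per-utterance number of labels
boundary[:, 3] = T              # per-utterance number of frames

# Pass 1: the simple (unpruned) loss over the linearly factored am + lm
# scores; return_grad=True also returns the occupation gradients that
# indicate which (t, u) cells of the full lattice actually matter.
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
    lm=lm, am=am, symbols=symbols,
    termination_symbol=blank_id, boundary=boundary,
    reduction="sum", return_grad=True,
)

# Pass 2: convert those gradients into per-frame pruning bounds, gather
# the surviving (t, u) window, and evaluate the joiner only inside it.
ranges = k2.get_rnnt_prune_ranges(
    px_grad=px_grad, py_grad=py_grad,
    boundary=boundary, s_range=prune_range,
)
am_pruned, lm_pruned = k2.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
logits = am_pruned + lm_pruned  # stand-in for a real (nonlinear) joiner
pruned_loss = k2.rnnt_loss_pruned(
    logits=logits, symbols=symbols, ranges=ranges,
    termination_symbol=blank_id, boundary=boundary, reduction="sum",
)
```

If you drop the simple loss, nothing trains the parameters that produce px_grad/py_grad, so the pruning bounds never become reasonable.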
Hey @danpovey @csukuangfj, thank you for your quick responses... You are my safety net here :smiley:
I've added the warmup functionality to my code the same way it was implemented here. Now, I'm gonna re-train the model and see if that improves the convergence. Will update this thread once it's done.
Also, I was wondering what is the best way to understand the pruned loss other than reading the code?
The paper is online now: https://arxiv.org/abs/2206.13236
Since the paper is out now, I'm gonna close this issue. Much appreciated!
Using my transducer model, I have tried both the pruned and the unpruned loss. The unpruned version worked pretty well, even outperforming `torchaudio.rnnt_loss`. The problem is with the pruned version: the model is very slow to converge, and the WER & CER are not improving even though I tried different `prune_range` values. Is this expected?
Also, I was wondering what is the best way to understand the pruned loss other than reading the code?