1ytic / warp-rnnt

CUDA-Warp RNN-Transducer

Normalize the RNN-T Loss with input seq length #30

Closed · aheba closed this issue 2 years ago

aheba commented 2 years ago

Hello,

We noticed that your implementation doesn't normalize the loss by the input sequence length. Here is an example of training on the TIMIT corpus:

RNNT loss torchaudio:

epoch: 1, lr_adam: 3.00e-04, lr_wav2vec: 1.00e-04 - train loss: 1.40e+02 - valid loss: 1.16e+02, valid PER: 1.00e+02
epoch: 2, lr_adam: 3.00e-04, lr_wav2vec: 1.00e-04 - train loss: 95.39 - valid loss: 64.14, valid PER: 91.21
epoch: 3, lr_adam: 3.00e-04, lr_wav2vec: 1.00e-04 - train loss: 35.57 - valid loss: 17.67, valid PER: 22.56
epoch: 4, lr_adam: 3.00e-04, lr_wav2vec: 1.00e-04 - train loss: 19.28 - valid loss: 12.31, valid PER: 16.15

RNNT loss spbrain:

epoch: 1, lr_adam: 3.00e-04, lr_wav2vec: 1.00e-04 - train loss: 1.06 - valid loss: 7.76e-01, valid PER: 1.00e+02
epoch: 2, lr_adam: 3.00e-04, lr_wav2vec: 1.00e-04 - train loss: 6.28e-01 - valid loss: 2.57e-01, valid PER: 54.77
epoch: 3, lr_adam: 3.00e-04, lr_wav2vec: 1.00e-04 - train loss: 2.16e-01 - valid loss: 1.08e-01, valid PER: 23.30
epoch: 4, lr_adam: 3.00e-04, lr_wav2vec: 1.00e-04 - train loss: 1.27e-01 - valid loss: 8.16e-02, valid PER: 14.56
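For concreteness, here is a minimal sketch (not from the thread) of the two conventions, written against `torchaudio.functional.rnnt_loss`; the shapes and values are made up. The per-frame variant simply divides each utterance's loss by its number of encoder frames, which would account for the roughly 100x scale gap between the two logs above (TIMIT utterances span a few hundred frames):

```python
import torch
import torchaudio

B, T, U, V = 4, 150, 20, 30  # batch, encoder frames, target length + 1, vocab size

logits = torch.randn(B, T, U, V)                              # joiner output
targets = torch.randint(1, V, (B, U - 1), dtype=torch.int32)  # keep blank (=0) out of targets
logit_lengths = torch.full((B,), T, dtype=torch.int32)
target_lengths = torch.full((B,), U - 1, dtype=torch.int32)

# Unnormalized: one loss value per utterance (the torchaudio convention above).
per_seq = torchaudio.functional.rnnt_loss(
    logits, targets, logit_lengths, target_lengths, blank=0, reduction="none"
)

# Normalized: divide each utterance's loss by its number of input frames.
per_frame = per_seq / logit_lengths

print(per_seq.mean().item())    # on the order of T
print(per_frame.mean().item())  # on the order of 1
```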
1ytic commented 2 years ago

AFAIK, normalizing by length doesn't make sense here: the RNN-T loss treats an entire sequence as a single input object. In practice, normalization just scales your gradient values, so to compare against a normalized loss you would need to scale the learning rate as well. That said, you can specify this parameter if you like.
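The gradient-scaling point can be checked directly: dividing the loss by a constant T divides every gradient by T, so SGD with learning rate lr on loss/T takes exactly the same step as lr/T on the raw loss. A toy check (not from the thread; the quadratic stands in for the RNN-T loss):

```python
import torch

x = torch.randn(10, requires_grad=True)
loss = (x ** 2).sum()  # stand-in for a raw (unnormalized) RNN-T loss
T = 150.0              # stand-in for the input sequence length

g_raw, = torch.autograd.grad(loss, x, retain_graph=True)
g_norm, = torch.autograd.grad(loss / T, x)

# Dividing the loss by T divides every gradient entry by T ...
assert torch.allclose(g_norm, g_raw / T)
# ... so lr on loss/T is the same SGD step as lr/T on the raw loss.
```

The parameter referred to appears to be the `average_frames` flag of `warp_rnnt.rnnt_loss`, which per the README divides each sentence's loss by its number of frames and defaults to False.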