TensorSpeech / TensorFlowASR

:zap: TensorFlowASR: Almost State-of-the-art Automatic Speech Recognition in Tensorflow 2. Supports languages that can use characters or subwords.
https://huylenguyen.com/asr
Apache License 2.0

training problem with rnn_transducer #279

Open yiqiaoc11 opened 1 year ago

yiqiaoc11 commented 1 year ago

The transducer in TensorFlowASR\examples\rnn_transducer doesn't train correctly on the current version, with either the current config.yml or the pretrained one. This is fundamental functionality. Can the author or someone give it a try to validate it?

nglehuy commented 1 year ago

Hi @yiqiaoc11 I'm training on TPUs to validate this. Are you using the warp-transducer loss or the rnnt loss in tensorflow? In my tests with the rnnt loss in tensorflow over the past months, it has had some convergence issues. But I don't have the resources to test on GPUs.
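For reference, a minimal sketch of calling the warp-transducer loss directly, assuming the `warprnnt_tensorflow` binding built from the HawkAaron/warp-transducer repo (the shapes follow that binding's documented convention; all sizes below are illustrative, not from this issue):

```python
# Hedged sketch: the warp-transducer TF binding, assuming warprnnt_tensorflow
# is installed from https://github.com/HawkAaron/warp-transducer
# (tensorflow_binding). Sizes are illustrative only.
import tensorflow as tf
from warprnnt_tensorflow import rnnt_loss

batch, max_t, max_u, vocab = 2, 50, 10, 30
logits = tf.random.normal([batch, max_t, max_u + 1, vocab])  # (B, T, U+1, V)
labels = tf.random.uniform([batch, max_u], minval=1, maxval=vocab,
                           dtype=tf.int32)                   # blank = 0
logit_lengths = tf.fill([batch], max_t)
label_lengths = tf.fill([batch], max_u)

loss = rnnt_loss(logits, labels, logit_lengths, label_lengths, blank_label=0)
print(loss)  # per-utterance negative log-likelihood
```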

yiqiaoc11 commented 1 year ago

@usimarit Thanks for the comments. All my tests were conducted on GPUs with batch size 2. I first tried TensorFlowASR\examples\rnn_transducer\config.yml with the warp-transducer loss and observed the same behavior as #231 (https://github.com/TensorSpeech/TensorFlowASR/issues/231). Then I switched to the rnnt-loss while using the pretrained model's .yml, which contains warmup_steps for the transformer scheduler. The loss is shown below:

Epoch 1/20 14269/14269 [==============================] - 5923s 415ms/step - loss: 339.2118 - val_loss: 160.0162
Epoch 2/20 14269/14269 [==============================] - 5923s 415ms/step - loss: 254.4009 - val_loss: 147.7653
Epoch 3/20 14269/14269 [==============================] - 5923s 415ms/step - loss: 241.0356 - val_loss: 144.2561
Epoch 4/20 14269/14269 [==============================] - 5923s 415ms/step - loss: 231.6980 - val_loss: 140.2191
Epoch 5/20 14269/14269 [==============================] - 5923s 415ms/step - loss: 223.2308 - val_loss: 137.6041
Epoch 6/20 14269/14269 [==============================] - 5923s 415ms/step - loss: 216.7098 - val_loss: 136.0396

With the rnnt-loss, 4- and 6-layer encoders converged given different warmup steps, but the 8-layer did not. I'm just trying to recover the performance of the pretrained model. Conformer reportedly works, and that setup differs only in the encoder, not in the rnn_transducer recipe.
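The "transformer scheduler" with warmup_steps mentioned above follows the schedule from Vaswani et al. (2017): lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5), so a larger warmup flattens the early ramp. A minimal sketch (the class name and d_model default here are illustrative, not TensorFlowASR's exact code):

```python
# Hedged sketch of the transformer LR schedule with warmup_steps, as in
# "Attention Is All You Need". Names/defaults are illustrative.
import tensorflow as tf

class TransformerSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model=320, warmup_steps=40000):
        super().__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = float(warmup_steps)

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        # Linear warmup for `warmup_steps` steps, then inverse-sqrt decay.
        return tf.math.rsqrt(self.d_model) * tf.minimum(
            tf.math.rsqrt(step), step * self.warmup_steps ** -1.5)

optimizer = tf.keras.optimizers.Adam(TransformerSchedule(warmup_steps=40000))
```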

Feel free to advise and I can try it on GPUs here.

nglehuy commented 1 year ago

@yiqiaoc11 Could you help me train 2 models for 30 epochs using the rnnt-loss:

  1. 4-layer encoder
  2. 8-layer encoder

Then plot the losses of the 2 models for better comparison? All other configs stay the same.
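A minimal sketch of that comparison plot, assuming each run logged its history with tf.keras.callbacks.CSVLogger (the CSV paths below are hypothetical):

```python
# Hedged sketch: overlay train/val loss of the 4-layer and 8-layer runs.
# Assumes CSVLogger output with columns: epoch, loss, val_loss.
import pandas as pd
import matplotlib.pyplot as plt

for path, label in [("4layer_history.csv", "4-layer encoder"),
                    ("8layer_history.csv", "8-layer encoder")]:
    hist = pd.read_csv(path)
    plt.plot(hist["epoch"], hist["loss"], label=f"{label} (train)")
    plt.plot(hist["epoch"], hist["val_loss"], "--", label=f"{label} (val)")

plt.xlabel("epoch")
plt.ylabel("RNN-T loss")
plt.legend()
plt.savefig("rnnt_4l_vs_8l.png")
```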

yiqiaoc11 commented 1 year ago

@usimarit, using the streaming config.yml (https://drive.google.com/file/d/1xYFYi3z94ZqaQZ-cTyiNekBwhITh1Ru4l) with warmup_steps=40000, right?

From the timeline, it seems you used the warp-transducer loss to produce the pretrained .h5 weights.

nglehuy commented 1 year ago

@yiqiaoc11 Yes, with the pretrained config.

I trained the rnn transducer on TPUs, where the warp-transducer loss cannot be applied, so only the rnnt-loss could be used. But you can experiment with the warp-transducer loss too, plotting the losses of the 2 models for better comparison.

yiqiaoc11 commented 1 year ago

@usimarit, I have 2 x 3090 now, so 2 x 30 epochs will take a fairly long time with the rnnt-loss. Currently the 8-layer model doesn't converge, while the 4-layer converges with > 40000 warmup steps. Conformer works with the same rnnt-loss. Could your pretrained rnn_transducer differ somehow, given the same loss, same optimizer, and same number of weights?

nglehuy commented 1 year ago

@yiqiaoc11 The rnn_transducer structure stays the same across v1.0.x. Is the number of weights in your case the same as in the pretrained example?
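A quick hedged way to answer that question: build both configurations and compare total and per-variable weight shapes. The two tiny functional models below are stand-ins; in practice, substitute the rnn_transducer build and the pretrained one:

```python
# Hedged sketch: check two builds for identical weight counts and shapes.
# The build() model is a placeholder, not the rnn_transducer architecture.
import tensorflow as tf

def build(units):
    inputs = tf.keras.Input(shape=(None, 80))               # (time, features)
    x = tf.keras.layers.LSTM(units, return_sequences=True)(inputs)
    outputs = tf.keras.layers.Dense(29)(x)                  # toy vocab size
    return tf.keras.Model(inputs, outputs)

mine, ref = build(256), build(256)
print("total params:", mine.count_params(), "vs", ref.count_params())
assert len(mine.weights) == len(ref.weights), "different number of variables"

# Compare positionally, since auto-generated layer names differ per build.
for wa, wb in zip(mine.weights, ref.weights):
    if tuple(wa.shape) != tuple(wb.shape):
        print("shape mismatch:", wa.name, tuple(wa.shape),
              "vs", wb.name, tuple(wb.shape))
```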

yiqiaoc11 commented 1 year ago

Yes, the number of weights and the distribution across layers are the same, but the other config information behind the pretrained model isn't traceable. Not sure what leads to the underfitting observed.

Preliminary loss curves for the 4- and 8-layer models are posted for comparison: green curves are the 8-layer model, blue the 4-layer. The losses are very similar; both models were tuned under the same .yml from GDrive, and neither converges. [image: training-loss curves, 4-layer vs 8-layer]

Model summary, 4-layer encoder:

Layer (type)                                                Output Shape   Param #
==================================================================================
streaming_transducer_encoder_reshape (Reshape)              multiple       0
streaming_transducer_encoder_block_0 (RnnTransducerBlock)   multiple       5511488
streaming_transducer_encoder_block_1 (RnnTransducerBlock)   multiple       7149888
streaming_transducer_encoder_block_2 (RnnTransducerBlock)   multiple       5839168
streaming_transducer_encoder_block_3 (RnnTransducerBlock)   multiple       5839168
==================================================================================
Total params: 24,339,712
Trainable params: 24,339,712
Non-trainable params: 0

Model summary, 8-layer encoder:

Layer (type)                                                Output Shape   Param #
==================================================================================
streaming_transducer_encoder_reshape (Reshape)              multiple       0
streaming_transducer_encoder_block_0 (RnnTransducerBlock)   multiple       5511488
streaming_transducer_encoder_block_1 (RnnTransducerBlock)   multiple       7149888
streaming_transducer_encoder_block_2 (RnnTransducerBlock)   multiple       5839168
streaming_transducer_encoder_block_3 (RnnTransducerBlock)   multiple       5839168
streaming_transducer_encoder_block_4 (RnnTransducerBlock)   multiple       5839168
streaming_transducer_encoder_block_5 (RnnTransducerBlock)   multiple       5839168
streaming_transducer_encoder_block_6 (RnnTransducerBlock)   multiple       5839168
streaming_transducer_encoder_block_7 (RnnTransducerBlock)   multiple       5839168
==================================================================================
Total params: 47,696,384
Trainable params: 47,696,384
Non-trainable params: 0
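The two summaries are internally consistent: the per-block parameter counts sum exactly to the printed totals, so the extra capacity in the 8-layer model is just four more identical blocks. A quick arithmetic check:

```python
# Sanity check on the printed summaries: per-block params sum to the totals.
block0, block1, block_n = 5511488, 7149888, 5839168

total_4layer = block0 + block1 + 2 * block_n   # blocks 2-3
total_8layer = block0 + block1 + 6 * block_n   # blocks 2-7

assert total_4layer == 24_339_712
assert total_8layer == 47_696_384
print(total_4layer, total_8layer)  # 24339712 47696384
```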