NVIDIA / NeMo

A scalable generative AI framework built for researchers and developers working on Large Language Models, Multimodal, and Speech AI (Automatic Speech Recognition and Text-to-Speech)
https://docs.nvidia.com/nemo-framework/user-guide/latest/overview.html
Apache License 2.0

[QUESTION] Best practices for RNN-T train/finetune #8330

Closed · kokamido closed this issue 7 months ago

kokamido commented 9 months ago

Hi! I would like to use an RNN-transducer for a Russian ASR task. I use a config like conformer_transducer_bpe.yaml, but a few things are unclear to me.

  1. Should I use RNN-T instead of CTC + greedy decoding in a non-streaming setup? Does RNN-T perform significantly better than CTC + beam search + a domain KenLM?
  2. I compared validation WERs for CTC and RNN-T setups from catalog.ngc.nvidia.com, e.g. STT En Conformer-CTC Large vs. STT En Conformer-Transducer Large. I understand that the audio encoders of these models are not identical, but they are similar. From this I conclude that the expected WER improvement from a CTC -> RNN-T switch is ~10% relative. My question is about this estimate. I already have ~12 WER and ~15 WER on my domain data (phone calls and online meetings respectively). In your experience/intuition, how likely is it to get a 1.2-1.5 absolute WER improvement from an RNN-T setup compared to CTC + greedy decoding with the same audio encoder?
  3. I am trying to use my previously trained CTC-Conformer as a frozen audio encoder: I train only a linear adapter, a prediction network, and a joint layer (see the sketch after this list). I use my dataset (~600 hours of speech: 300 h of phone calls and 300 h of online meetings) and the parameters from the NeMo base RNN-T config (excluding the encoder parameters). It converges, but its WER is barely better than the WER of my CTC model; the difference is ~0.2 WER. Did you try to train an RNN-T model with the encoder frozen? Does RNN-T converge better if the encoder is trainable?
  4. I heard that the CTC head of a hybrid model may perform better than a pure CTC model trained on the same data. Have you noticed this effect?
  5. Should a prediction network act like a "normal" language model? I mean, if I use the default NeMo RNN-T prednet
    prednet:
      pred_hidden: ${model.model_defaults.pred_hidden}
      pred_rnn_layers: 1
      t_max: null
      dropout: 0.2

    can I use its performance on the next-token prediction task as a proxy metric for its impact on the performance of the whole RNN-T model? Should I consider different metrics to determine whether the performance of my RNN-T model is hurt by a "bad" prednet rather than by the audio encoder?
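
For reference, here is a minimal sketch of the frozen-encoder setup from item 3, assuming a recent NeMo release; the checkpoint name is one of the public NGC models, and `freeze()` follows NeMo's module API:

    import nemo.collections.asr as nemo_asr

    # Load a public transducer checkpoint (any EncDecRNNTBPEModel works).
    model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
        "stt_en_conformer_transducer_large"
    )

    # Freeze the acoustic encoder; NeMo modules expose freeze()/unfreeze().
    model.encoder.freeze()

    # Only the prediction network (decoder) and joint remain trainable.
    trainable = [n for n, p in model.named_parameters() if p.requires_grad]
    print(f"{len(trainable)} trainable parameter tensors remain")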

nithinraok commented 8 months ago

Both RNN-Ts and CTC models have pros and cons in general. RNN-Ts converge faster than CTC models and, due to their implicit-LM prediction nature, can get better WER with just greedy decoding; however, RNN-Ts are ~3x slower in training/inference than CTC models. That gap can be narrowed with a TDT-style decoder plus CUDA graph improvements.
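
As a rough way to feel out the difference on your own data, here is a hedged sketch comparing greedy decoding of the public CTC and transducer checkpoints (the audio path is a placeholder; `transcribe` kwargs have changed across NeMo versions, so the files are passed positionally):

    import nemo.collections.asr as nemo_asr

    ctc = nemo_asr.models.ASRModel.from_pretrained("stt_en_conformer_ctc_large")
    rnnt = nemo_asr.models.ASRModel.from_pretrained("stt_en_conformer_transducer_large")

    audio = ["call_recording.wav"]  # placeholder path to your own audio
    print("CTC  :", ctc.transcribe(audio))
    print("RNN-T:", rnnt.transcribe(audio))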

Should I use RNN-T instead of CTC + greedy decoding in a non-streaming setup? Does RNN-T perform significantly better than CTC + beam search + a domain KenLM?

If inference speed is not very critical, consider RNN-T.

how likely is it to get a 1.2-1.5 absolute WER improvement from an RNN-T setup compared to CTC + greedy decoding with the same audio encoder?

Those CTC and RNN-T models are trained on 25k hours of data, so when moving to different data you might not see the same improvements. However, fine-tuning with an RNN-T decoder is always recommended, as it converges faster than CTC models. Use a checkpoint trained on a large amount of data, even if it's not the same language.
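
A sketch of that route, assuming a pretrained English checkpoint and a pre-built Russian BPE tokenizer. The tokenizer directory is hypothetical; `change_vocabulary` is the usual NeMo entry point for swapping tokenizers, but check the signature against your NeMo version:

    import nemo.collections.asr as nemo_asr

    model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
        "stt_en_conformer_transducer_large"  # English checkpoint, large data
    )
    # Swap the English tokenizer for a Russian one before fine-tuning.
    model.change_vocabulary(
        new_tokenizer_dir="tokenizers/ru_bpe_1024",  # hypothetical path
        new_tokenizer_type="bpe",
    )
    # ...then point the model at Russian manifests and fine-tune as usual.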

Did you try to train an RNN-T model with the encoder frozen? Does RNN-T converge better if the encoder is trainable?

Generally our training recipe involves fine-tuning the encoder with a very small learning rate for a few thousand steps. In this scenario, across multiple languages, fine-tuning from an English ASR encoder gave RNN-Ts better performance. Unfortunately we cannot comment on the exact WER improvement, as it depends on the type of data, the language, and the hyperparameters.
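
One way to approximate "very small encoder LR" with plain PyTorch parameter groups, assuming `model` is a loaded NeMo model as above; the learning rates are illustrative, not the official recipe:

    import torch

    # Split parameters into encoder vs everything else (prednet, joint, ...).
    enc_params = [p for n, p in model.named_parameters() if n.startswith("encoder")]
    rest_params = [p for n, p in model.named_parameters() if not n.startswith("encoder")]

    optimizer = torch.optim.AdamW(
        [
            {"params": enc_params, "lr": 1e-5},   # very small LR for the encoder
            {"params": rest_params, "lr": 1e-3},  # normal LR for prednet/joint
        ],
        weight_decay=1e-3,
    )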

I heard that the CTC head of a hybrid model may perform better than a pure CTC model trained on the same data. Have you noticed this effect?

Yes, the CTC head of a hybrid RNNT-CTC model performs better and also converges faster than plain CTC training. You may alternatively try fine-tuning a CTC head from an RNN-T encoder as well.
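
A hedged sketch of using both heads of a hybrid checkpoint; the model name is a public NGC hybrid model, and `change_decoding_strategy(decoder_type=...)` follows NeMo's hybrid-model API, though the exact signature may vary by version:

    import nemo.collections.asr as nemo_asr

    hybrid = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.from_pretrained(
        "stt_en_fastconformer_hybrid_large_pc"
    )
    # One encoder, two heads: pick which one decodes at inference time.
    hybrid.change_decoding_strategy(decoder_type="rnnt")  # transducer head
    hybrid.change_decoding_strategy(decoder_type="ctc")   # faster CTC head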

Should I consider different metrics to determine whether the performance of my RNN-T model is hurt by a "bad" prednet rather than by the audio encoder?

Very good question. @titu1994 / @VahidooX, do you know? Also, please add anything I'm missing.

kokamido commented 8 months ago

@titu1994, @VahidooX, could you share your views and experience on this? Do you know a way to relate the prednet's impact on the RNN-T's quality to the "usual" LM metrics of the prednet? Maybe I can just use the prednet's quality on the next-token prediction task as a proxy metric?

titu1994 commented 8 months ago

I think nithin covered it pretty well.

Using the prednet as an LM doesn't work unless you are using the Hybrid Autoregressive Transducer (HAT) decoder instead of the original one. The original prednet performs not just language-token prediction but also fused blank-token prediction, which splits next-token prediction into two entangled logic paths that you can't really separate due to the global softmax. HAT has separate heads for token and blank prediction, so it's safer to use its LM head as a pure LM to measure entropy or negative log-likelihood.
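
To make the distinction concrete, here is a sketch of the HAT-style factorization (separate blank and label heads); names and shapes are illustrative, not NeMo's actual implementation:

    import torch
    import torch.nn.functional as F

    def hat_joint(blank_logit: torch.Tensor, label_logits: torch.Tensor):
        """blank_logit: [...], one logit per step; label_logits: [..., V]."""
        p_blank = torch.sigmoid(blank_logit)                # separate blank model
        log_p_labels = F.log_softmax(label_logits, dim=-1)  # pure internal LM
        # P(token) = (1 - P(blank)) * softmax(label_logits)
        log_p_tokens = torch.log1p(-p_blank).unsqueeze(-1) + log_p_labels
        return p_blank, log_p_tokens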

vadimkantorov commented 8 months ago

@titu1994 I guess one could also ask a more general question: to what extent does the LM learned in transducer-like models actually benefit from being a strong LM (with high scores on metrics traditionally measured for LMs, like perplexity or next-token prediction accuracy)? Maybe one way to test this would be to freeze the LM in the transducer and learn a linear probe for the next-token prediction task (see the sketch below). Have you tried doing that?
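
A minimal sketch of that probe, assuming the prediction network returns (hidden states, state) like an LSTM; all names here are hypothetical:

    import torch
    import torch.nn as nn

    class PrednetProbe(nn.Module):
        def __init__(self, prednet: nn.Module, hidden_dim: int, vocab_size: int):
            super().__init__()
            self.prednet = prednet.eval()
            for p in self.prednet.parameters():
                p.requires_grad = False                    # the frozen LM under study
            self.head = nn.Linear(hidden_dim, vocab_size)  # only trainable part

        def forward(self, tokens: torch.Tensor) -> torch.Tensor:
            with torch.no_grad():
                h, _ = self.prednet(tokens)   # assumed (hidden, state) output
            return self.head(h)               # next-token logits

    # Train with cross-entropy against tokens shifted by one; report perplexity.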

This question is probably better studied in the shallow-fusion setting, where acoustic-model and language-model scores are combined explicitly. Do you know of any literature that has studied it systematically?

If a good LM leads to better ASR quality, then it would seem logical to try plugging in a larger pretrained LM (at least in the quest for the best WERs, reusing LLMs pretrained on a larger corpus).

On the other hand, it might be that with a strong AM, the LM in transducer models just learns to correct certain input phrases when the acoustic model struggles to get them right. So it would be very interesting to see a study on this.

titu1994 commented 8 months ago

The transducer LM does not need to be strong at all. In fact, there's a Microsoft paper showing that a stateless Conv1d decoder with a fixed context of just 2 tokens (the current and previous token) can get almost the same performance as a full-fledged LSTM decoder.
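
A sketch of that stateless idea: embed only the last two tokens and mix them with a single Conv1d instead of an LSTM (dimensions illustrative):

    import torch
    import torch.nn as nn

    class StatelessPrednet(nn.Module):
        def __init__(self, vocab_size: int, dim: int, context: int = 2):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, dim)
            # A single conv over a fixed window of `context` tokens; no state.
            self.mix = nn.Conv1d(dim, dim, kernel_size=context)

        def forward(self, last_tokens: torch.Tensor) -> torch.Tensor:
            # last_tokens: [B, context], only the previous and current token.
            e = self.embed(last_tokens).transpose(1, 2)  # [B, dim, context]
            return self.mix(e).squeeze(-1)               # [B, dim]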

In our experience, the acoustic model is the primary source of measurable improvement. A linear-probe experiment sounds cool, but I don't expect superb perplexity results.

In NeMo, we also implement a stateless decoder for TDT.

vadimkantorov commented 8 months ago

This is very interesting and rather counter-intuitive, as it somewhat goes against the obvious idea of plugging in a powerful pretrained LM :) (similar to how re-ranking can be done with large LMs).

github-actions[bot] commented 7 months ago

This issue is stale because it has been open for 30 days with no activity. Remove stale label or comment or this will be closed in 7 days.

github-actions[bot] commented 7 months ago

This issue was closed because it has been inactive for 7 days since being marked as stale.