NVIDIA / NeMo

A scalable generative AI framework built for researchers and developers working on Large Language Models, Multimodal, and Speech AI (Automatic Speech Recognition and Text-to-Speech)
https://docs.nvidia.com/nemo-framework/user-guide/latest/overview.html
Apache License 2.0

ASR - WER not decreasing after certain point (Finetuning hybrid_cache_aware_streaming model) #10578

Status: Open · rkchamp25 opened this issue 1 month ago

rkchamp25 commented 1 month ago

I am trying to finetune the STT En FastConformer Hybrid Transducer-CTC Large Streaming Multi model.

Issue: The WER does not go below 22.5%. I need suggestions on how I can improve this, or pointers to anything I might be doing wrong.

Details:

python3 -u NeMo/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_bpe.py \
    --config-path="../conf/fastconformer/hybrid_cache_aware_streaming" \
    --config-name="fastconformer_hybrid_transducer_ctc_bpe_streaming" \
    +init_from_pretrained_model="stt_en_fastconformer_hybrid_large_streaming_multi" \
    model.train_ds.manifest_filepath="data/train_20.json" \
    model.validation_ds.manifest_filepath="data/val_20.json" \
    model.tokenizer.dir="default_tokenizer" \
    model.tokenizer.type="bpe" \
    trainer.max_epochs=300 \
    model.optim.name="adamw" \
    model.optim.weight_decay=0.0001 \
    model.optim.sched.warmup_steps=2000 \
    model.aux_ctc.ctc_loss_weight=0.3 \
    model.optim.lr=0.005 \
    model.optim.betas=[0.9,0.999] \
    ++exp_manager.exp_dir="checkpoints" \
    ++exp_manager.version=one \
    ++exp_manager.use_datetime_version=False \
    ++exp_manager.resume_ignore_no_checkpoint=True
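
(As a sanity check before launching this, I load the pretrained checkpoint once to confirm the tokenizer and attention contexts it ships with. This snippet is my own sketch, not part of the original command; the attribute paths are assumptions based on the config dump below.)

```python
# Sketch: inspect the pretrained checkpoint before finetuning.
# Attribute paths are assumptions based on the config dump below.
import nemo.collections.asr as nemo_asr

model = nemo_asr.models.ASRModel.from_pretrained(
    model_name="stt_en_fastconformer_hybrid_large_streaming_multi"
)
print(model.cfg.encoder.att_context_size)  # streaming attention contexts
print(model.cfg.tokenizer)                 # tokenizer packed inside the .nemo file
print(model.tokenizer.vocab_size)          # should match decoder num_classes (1024)
```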

Config:

cfg:
  sample_rate: 16000
  compute_eval_loss: false
  log_prediction: true
  skip_nan_grad: false
  model_defaults:
    enc_hidden: 512
    pred_hidden: 640
    joint_hidden: 640
  train_ds:
    manifest_filepath: data/train_20.json
    sample_rate: 16000
    batch_size: 8
    shuffle: true
    num_workers: 8
    pin_memory: true
    max_duration: 20
    min_duration: 5
    is_tarred: false
    tarred_audio_filepaths: null
    shuffle_n: 2048
    bucketing_strategy: synced_randomized
    bucketing_batch_size: null
  validation_ds:
    manifest_filepath: data/val_20.json
    sample_rate: 16000
    batch_size: 8
    shuffle: false
    use_start_end_token: false
    num_workers: 8
    pin_memory: true
  test_ds:
    manifest_filepath: null
    sample_rate: 16000
    batch_size: 8
    shuffle: false
    use_start_end_token: false
    num_workers: 8
    pin_memory: true
  tokenizer:
    dir: default_tokenizer
    type: bpe
  preprocessor:
    _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
    sample_rate: 16000
    normalize: NA
    window_size: 0.025
    window_stride: 0.01
    window: hann
    features: 80
    n_fft: 512
    frame_splicing: 1
    dither: 1.0e-05
    pad_to: 0
  spec_augment:
    _target_: nemo.collections.asr.modules.SpectrogramAugmentation
    freq_masks: 2
    time_masks: 10
    freq_width: 27
    time_width: 0.05
  encoder:
    _target_: nemo.collections.asr.modules.ConformerEncoder
    feat_in: 80
    feat_out: -1
    n_layers: 17
    d_model: 512
    use_bias: true
    subsampling: dw_striding
    subsampling_factor: 8
    subsampling_conv_channels: 256
    causal_downsampling: true
    ff_expansion_factor: 4
    self_attention_model: rel_pos
    n_heads: 8
    att_context_size:
    - 70
    - 1
    att_context_style: chunked_limited
    att_context_probs: null
    xscaling: true
    pos_emb_max_len: 5000
    conv_kernel_size: 9
    conv_norm_type: layer_norm
    conv_context_size: causal
    dropout: 0.1
    dropout_pre_encoder: 0.1
    dropout_emb: 0.0
    dropout_att: 0.1
    stochastic_depth_drop_prob: 0.0
    stochastic_depth_mode: linear
    stochastic_depth_start_layer: 1
  decoder:
    _target_: nemo.collections.asr.modules.RNNTDecoder
    normalization_mode: null
    random_state_sampling: false
    blank_as_pad: true
    prednet:
      pred_hidden: 640
      pred_rnn_layers: 1
      t_max: null
      dropout: 0.2
    vocab_size: 1024
  joint:
    _target_: nemo.collections.asr.modules.RNNTJoint
    log_softmax: null
    preserve_memory: false
    fuse_loss_wer: true
    fused_batch_size: 4
    jointnet:
      joint_hidden: 640
      activation: relu
      dropout: 0.2
      encoder_hidden: 512
      pred_hidden: 640
    num_classes: 1024
  decoding:
    strategy: greedy_batch
    greedy:
      max_symbols: 10
    beam:
      beam_size: 2
      return_best_hypothesis: false
      score_norm: true
      tsd_max_sym_exp: 50
      alsd_max_target_len: 2.0
  aux_ctc:
    ctc_loss_weight: 0.3
    use_cer: false
    ctc_reduction: mean_batch
    decoder:
      _target_: nemo.collections.asr.modules.ConvASRDecoder
      feat_in: null
      num_classes: 1024
    decoding:
      strategy: greedy
  interctc:
    loss_weights: []
    apply_at_layers: []
  loss:
    loss_name: default
    warprnnt_numba_kwargs:
      fastemit_lambda: 0.005
      clamp: -1.0
  optim:
    name: adamw
    lr: 0.005
    betas:
    - 0.9
    - 0.999
    weight_decay: 0.0001
    sched:
      name: NoamAnnealing
      d_model: 512
      warmup_steps: 2000
      warmup_ratio: null
      min_lr: 1.0e-06
    target: nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models.EncDecHybridRNNTCTCBPEModel
  nemo_version: 2.1.0rc0
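
One note on this config: my understanding (an assumption on my part, not verified against the NeMo source) is that with aux_ctc.ctc_loss_weight = 0.3 both decoders are trained jointly, with the total loss being a weighted sum along these lines:

```python
# Assumed loss combination for the hybrid model (sketch of my understanding,
# not NeMo source code): RNNT and aux CTC are trained jointly.
def hybrid_loss(rnnt_loss: float, ctc_loss: float, ctc_loss_weight: float = 0.3) -> float:
    return (1.0 - ctc_loss_weight) * rnnt_loss + ctc_loss_weight * ctc_loss

# e.g. per-batch losses of 40.0 (RNNT) and 60.0 (CTC):
print(hybrid_loss(40.0, 60.0))  # 0.7 * 40 + 0.3 * 60 = 46.0
```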

Things changed from default config:

- batch_size: 8
- att_context_size: [70,1], because I need to use this model for real-time transcription. I experimented with Online_ASR_Microphone_Demo_Cache_Aware_Streaming.ipynb and found that [70,1] gives the best latency for my use case (lookahead_size = 80ms and encoder_step_length = 80ms); see the latency arithmetic sketched below.
- learning_rate: tried different values (0.005, 0.0005, 0.0001) but saw no difference.
- min_duration = 5s, max_duration = 20s
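
For context, here is the back-of-envelope arithmetic behind those 80ms numbers (my own calculation from the config values above, not something NeMo prints):

```python
# Back-of-envelope latency arithmetic for att_context_size = [70, 1]
# (my own calculation from the config values above, not NeMo output).
window_stride_s = 0.01   # preprocessor.window_stride
subsampling     = 8      # encoder.subsampling_factor
right_context   = 1      # encoder.att_context_size[1]

encoder_step_ms = window_stride_s * 1000 * subsampling   # 80 ms per encoder frame
lookahead_ms    = right_context * encoder_step_ms        # 80 ms of future context
print(encoder_step_ms, lookahead_ms)                     # 80.0 80.0
```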

Tokenizer:

Used the default tokenizer: saved the tokenizer shipped with the pretrained model and then used it for finetuning (1024-token SentencePiece Unigram).
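
In case the exact steps matter, this is roughly how the packed tokenizer can be pulled out of the downloaded checkpoint (a sketch: a .nemo file is a tar archive, but the artifact names inside vary by model, so the filter below is an assumption and the members should be listed first):

```python
# Sketch: extract the packed tokenizer from the downloaded .nemo checkpoint.
# A .nemo file is a tar archive; artifact names vary by model, so list them first.
import tarfile

nemo_path = "stt_en_fastconformer_hybrid_large_streaming_multi.nemo"  # assumed local path
with tarfile.open(nemo_path, "r:*") as tar:
    print(tar.getnames())  # look for tokenizer.model / vocab.txt style artifacts
    for member in tar.getmembers():
        if "tokenizer" in member.name or "vocab" in member.name:
            tar.extract(member, path="default_tokenizer")
```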

Dataset:

train: 120 hrs of English audio (doctor-patient conversations, i.e. recordings of meetings, containing medical jargon and American accents); audio files are 5s-20s; all transcripts were converted to contain only lowercase English letters, spaces, and apostrophes.
val: 15 hrs of audio in the same format as above.
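
For reference, the manifests follow the usual NeMo ASR layout of one JSON object per line; the paths, durations, and text below are invented purely for illustration:

```python
# Illustration of the manifest layout used for train_20.json / val_20.json
# (file paths, durations and text here are made up).
import json

rows = [
    {"audio_filepath": "/data/audio/consult_0001.wav", "duration": 12.4,
     "text": "the patient reports chest pain since yesterday"},
    {"audio_filepath": "/data/audio/consult_0002.wav", "duration": 7.9,
     "text": "we'll order a complete blood count today"},
]
with open("example_manifest.json", "w") as f:
    for row in rows:
        f.write(json.dumps(row) + "\n")
```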

WER of the original model on the val set = 41%. The WER reaches 22-25% in all runs within the first 10-15 epochs and after that it remains almost constant. The maximum number of epochs tried is 90.
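
This is roughly how I measure that baseline (a sketch of my own; depending on the NeMo version, transcribe() may return plain strings, Hypothesis objects, or a (best, all) tuple for RNNT models, hence the normalization below):

```python
# Sketch: baseline WER of the pretrained model on the val manifest.
import json
import nemo.collections.asr as nemo_asr
from nemo.collections.asr.metrics.wer import word_error_rate

model = nemo_asr.models.ASRModel.from_pretrained(
    "stt_en_fastconformer_hybrid_large_streaming_multi"
)

paths, refs = [], []
with open("data/val_20.json") as f:
    for line in f:
        entry = json.loads(line)
        paths.append(entry["audio_filepath"])
        refs.append(entry["text"])

hyps = model.transcribe(paths, batch_size=8)
if isinstance(hyps, tuple):   # some RNNT versions return (best_hyps, all_hyps)
    hyps = hyps[0]
hyps = [h.text if hasattr(h, "text") else h for h in hyps]
print(word_error_rate(hypotheses=hyps, references=refs))
```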

GPU: V100 16GB

Tensorboard logs of one of the runs: [five screenshots attached]

UPDATE (30/09/2024)

Tried adding more data to the dataset. The whole dataset is shuffled and then the train and val sets are created. It didn't make any difference; if anything, the lowest VAL_WER is now higher with the new dataset.

TRAIN: 220 hrs (~100 hrs more than the previous data)
VAL: 17 hrs
Audio lengths: 1s-20s (1s, 2s, 3s, ..., 19s, 20s)
Learning rate: 0.001
Batch sizes tried: 8 (V100), 32 (A100)
Epochs: 45 (V100), 95 and 300 (A100)
VAL_WER: ~23-24% in all 3 runs

For some reason it looks like the learning rate drops to the minimum learning rate (1e-6) right from the start and then stays there; I'm not sure why. I also don't understand why the training loss graph looks the way it does below. And I'm not sure what is actually being trained with this config: CTC, RNNT, or both? (I use decoder_type = "rnnt" when doing inference with this model.)
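
To make sense of the pinned learning rate, here is my back-of-envelope reconstruction of the schedule, assuming NeMo's NoamAnnealing follows the standard Noam formula with min_lr applied as a floor after warmup (an assumption, not taken from the NeMo source):

```python
# Assumed Noam schedule:
# lr(step) = base_lr * d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
def noam_lr(base_lr, step, d_model=512, warmup_steps=2000, min_lr=1e-6):
    lr = base_lr * d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
    return max(lr, min_lr) if step > warmup_steps else lr

print(noam_lr(0.005, 2000))    # ~4.9e-06  (peak, at the end of warmup)
print(noam_lr(0.005, 20000))   # ~1.6e-06  (already almost at min_lr)
```

If that formula is right, a base lr of 0.005 never produces an effective lr much above ~5e-6, which would explain why the logged lr sits near the floor almost from the start.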

Find tensorboard logs attached below

[six Tensorboard screenshots]

UPDATE (3/10/2024)

I read in a few answers on other issues that when using NoamAnnealing, which is the default scheduler in the config (fastconformer_hybrid_transducer_ctc_bpe_streaming.yaml), the configured lr acts as a multiplier and hence should be 1/10th or 1/5th of the original value.

I tried lr = 0.5 (1 A100 GPU, batch_size=32) and lr = 1.0 (4 A100 GPUs, batch_size=32). Min WER = 22% (44 epochs on 1 GPU, 175 epochs on 4 GPUs); it seems almost stuck and doesn't go below this WER with the new learning rates either.

Apart from this, I also tried the changes below, which didn't make any difference either:
- att_context_size = [70,6] instead of [70,1]
- accumulate_grad_batches = 8 instead of 1

@nithinraok can you please have a look at this and share any suggestions or insights you have based on the details provided? Thank you.

Attaching the logs from the latest runs: [five screenshots attached]

@nithinraok @titu1994 @elliottnv Any help/suggestions would be appreciated. Thank you.

github-actions[bot] commented 16 hours ago

This issue is stale because it has been open for 30 days with no activity. Remove stale label or comment or this will be closed in 7 days.