NVIDIA / NeMo

A scalable generative AI framework built for researchers and developers working on Large Language Models, Multimodal, and Speech AI (Automatic Speech Recognition and Text-to-Speech)
https://docs.nvidia.com/nemo-framework/user-guide/latest/overview.html
Apache License 2.0

Training conformer_ctc with korean #3243

Closed hslee4716 closed 2 years ago

hslee4716 commented 2 years ago

I am trying to train the Conformer-CTC model on KsponSpeech, a Korean speech dataset.

KsponSpeech - 1,000 hours / 123 GB / 630,000 PCM audio files (fs=16000, sample_width=2, channels=1)

On the first try, I followed the existing Conformer-CTC configuration guidelines, only building a Korean vocabulary of about 5,000 tokens. Each epoch took about 4 hours; after roughly 5 epochs the loss had hardly decreased and the WER stayed pinned at almost 100%. Inference produced only spaces or periods (.).

[screenshot of inference output]

On the second try, as a test, I reduced the dataset to about 6,300 files and the vocab size to 3,200, but after training for about 20 hours (460 epochs) the model still showed no sign of converging.

[screenshot of training results]

Since then I have kept experimenting, changing the hyperparameters little by little, but the same behavior keeps occurring.

Is there anything you can guess that might be wrong?

Is it simply a lack of training time?

Also, I would like to know how much the learning rate affects training when using the Noam scheduler.

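(For reference: with NoamAnnealing, the lr value in the config acts as a scale factor on the Noam curve rather than as the actual step size, so the effective peak LR also depends on d_model and warmup_steps. A minimal sketch of the standard Noam formula, which NeMo's NoamAnnealing follows up to the min_lr clamp - check the scheduler source for your version:)

# Sketch of the standard Noam schedule; the configured lr is a multiplier,
# not the learning rate actually applied at each step.
def noam_lr(step, lr=2.0, d_model=256, warmup_steps=1000, min_lr=1e-6):
    step = max(step, 1)
    scale = (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)
    return max(lr * scale, min_lr)

print(noam_lr(1000))  # peak effective LR ~4e-3 with lr=2.0, d_model=256, warmup_steps=1000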

name: "Conformer-CTC-BPE-medium"
model:
  sample_rate: 16000
  log_prediction: true # enables logging sample predictions in the output during training
  ctc_reduction: 'mean_batch'

  train_ds:
    manifest_filepath: /ksponspeech/subword_preparied_211124_3/train
    sample_rate: ${model.sample_rate}
    batch_size: 32 # you may increase batch_size if your memory allows
    shuffle: true
    num_workers: 8
    pin_memory: true
    use_start_end_token: false
    trim_silence: false
    max_duration: 20.0 # it is set for LibriSpeech, you may need to update it for your dataset
    min_duration: 0.1

  validation_ds:
    manifest_filepath: /ksponspeech/subword_preparied_211124_3/test
    sample_rate: ${model.sample_rate}
    batch_size: 16 # you may increase batch_size if your memory allows
    shuffle: false
    num_workers: 8
    pin_memory: true
    use_start_end_token: false

  # recommend small vocab size of 128 or 256 when using 4x sub-sampling
  # you may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py
  tokenizer:
    dir: /home/aistudio/hs/ksponspeech/subword_preparied_211124_3/tokenizer_spe_bpe_v5000  # path to directory which contains either tokenizer.model (bpe) or vocab.txt (wpe)
    type: bpe  # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer)

  preprocessor:
    _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
    sample_rate: ${model.sample_rate}
    normalize: "per_feature"
    window_size: 0.025
    window_stride: 0.01
    window: "hann"
    features: 80
    n_fft: 512
    log: true
    frame_splicing: 1
    dither: 0.00001
    pad_to: 0
    pad_value: 0.0

  spec_augment:
    _target_: nemo.collections.asr.modules.SpectrogramAugmentation
    freq_masks: 2 # set to zero to disable it
    # you may use lower time_masks for smaller models to have a faster convergence
    time_masks: 5 # set to zero to disable it
    freq_width: 27
    time_width: 0.05

  encoder:
    _target_: nemo.collections.asr.modules.ConformerEncoder
    feat_in: ${model.preprocessor.features}
    feat_out: -1 # you may set it if you need different output size other than the default d_model
    n_layers: 18
    d_model: 256

    # Sub-sampling params
    subsampling: striding # vggnet or striding, vggnet may give better results but needs more memory
    subsampling_factor: 16 # must be power of 2
    subsampling_conv_channels: -1 # -1 sets it to d_model

    # Feed forward module's params
    ff_expansion_factor: 4

    # Multi-headed Attention Module's params
    self_attention_model: rel_pos # rel_pos or abs_pos
    n_heads: 4 # may need to be lower for smaller d_models
    # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
    att_context_size: [-1, -1] # -1 means unlimited context
    xscaling: true # scales up the input embeddings by sqrt(d_model)
    untie_biases: true # unties the biases of the TransformerXL layers
    pos_emb_max_len: 5000

    # Convolution module's params
    conv_kernel_size: 31

    ### regularization
    dropout: 0.2 # The dropout used in most of the Conformer Modules
    dropout_emb: 0.0 # The dropout used for embeddings
    dropout_att: 0.1 # The dropout for multi-headed attention modules

  decoder:
    _target_: nemo.collections.asr.modules.ConvASRDecoder
    feat_in: null
    num_classes: -1
    vocabulary: []

  optim:
    name: adamw
    lr: 2.0 # 5 - normal
    # optimizer arguments
    betas: [0.9, 0.98]
    # less necessity for weight_decay as we already have large augmentations with SpecAug
    # you may need weight_decay for large models, stable AMP training, small datasets, or when lower augmentations are used
    # weight decay of 0.0 with lr of 2.0 also works fine
    weight_decay: 0.0

    # scheduler setup
    sched:
      name: NoamAnnealing
      d_model: ${model.encoder.d_model}
      # scheduler config override
      warmup_steps: 1000
      warmup_ratio: null
      min_lr: 1e-6

trainer:
  gpus: -1 # number of GPUs, -1 would use all available GPUs
  num_nodes: 1
  max_epochs: 1000
  max_steps: null # computed at runtime if not set
  val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
  accelerator: ddp
  accumulate_grad_batches: 1
  gradient_clip_val: 0.0
  precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
  log_every_n_steps: 1000  # Interval of logging.
  progress_bar_refresh_rate: 1
  resume_from_checkpoint: /home/aistudio/hs/nemo_experiments/Conformer-CTC-BPE/2021-11-24_00-50-48/checkpoints/Conformer-CTC-BPE--val_wer=0.9739-epoch=95-last.ckpt # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
  num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it
  check_val_every_n_epoch: 5 # number of evaluations on validation every n epochs
  sync_batchnorm: true
  checkpoint_callback: false  # Provided by exp_manager
  logger: false  # Provided by exp_manager

exp_manager:
  exp_dir: null
  name: ${name}
  create_tensorboard_logger: true
  create_checkpoint_callback: true
  checkpoint_callback_params:
    # in case of multiple validation sets, first one is used
    monitor: "val_wer"
    mode: "min"
    save_top_k: 3
    always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints

  # you need to set these two to True to continue the training
  resume_if_exists: false
  resume_ignore_no_checkpoint: false

{"audio_filepath": "/korean_speech/original/KsponSpeech_03/KsponSpeech_0273/KsponSpeech_272305.pcm", "duration": 8.65, "text": "방송 공연학과 그담에 스물 한 살 때 영화방송제작학과 다음에 이제 졸업한 학과는 호텔관광학과."}
{"audio_filepath": "/korean_speech/original/KsponSpeech_03/KsponSpeech_0273/KsponSpeech_272132.pcm", "duration": 3.76, "text": "어 서로 연애 상담해주고 그랬는데"}
{"audio_filepath": "/korean_speech/original/KsponSpeech_03/KsponSpeech_0273/KsponSpeech_272801.pcm", "duration": 1.71, "text": "아 그런 사람을 진짜 데려와?"}
{"audio_filepath": "/korean_speech/original/KsponSpeech_03/KsponSpeech_0273/KsponSpeech_272683.pcm", "duration": 3.93, "text": "이제욱 쌤 아 안 이제욱 쌤 원래 안 친했어."}
titu1994 commented 2 years ago

For Korean you would want to measure CER (character error rate, not word error rate) - add model.use_cer = True in the model config.

What is the average audio duration? If it is less than 10 seconds, you will want to reduce the SpecAugment time masks to 2 instead of 10.

To speed up convergence you should load the encoder weights from a pretrained English model - an example of how to load the pretrained checkpoint and partially load its weights is in the tutorial for ASR finetuning on another language.
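(A rough sketch of the partial weight loading described above; the pretrained model name is an assumption, and cfg/trainer are assumed to be built elsewhere - the finetuning tutorial is the authoritative recipe:)

import nemo.collections.asr as nemo_asr

# Load a pretrained English Conformer-CTC and copy its encoder weights into the new model.
# "stt_en_conformer_ctc_medium" is an assumed NGC model name; check the catalog for your version.
pretrained = nemo_asr.models.EncDecCTCModelBPE.from_pretrained("stt_en_conformer_ctc_medium")

# korean_model is built from the YAML config above (cfg) and a Lightning trainer built elsewhere.
korean_model = nemo_asr.models.EncDecCTCModelBPE(cfg=cfg.model, trainer=trainer)
korean_model.encoder.load_state_dict(pretrained.encoder.state_dict(), strict=False)
# The decoder is left untouched so it keeps the Korean tokenizer's vocabulary size.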

hslee4716 commented 2 years ago

Thanks for the advice! I'll try the tutorial

hslee4716 commented 2 years ago

The model seems to converge at a much faster rate than before, and it's starting to babble!

But it still seems to take a lot of time :disappointed_relieved:

feddybear commented 2 years ago

Hi @titu1994, I'm currently trying out Conformer-CTC on Japanese and I set model.use_cer=True. Despite this, the model summary table still shows a WERBPE type (under _wer). What could the issue be? I read somewhere that the attribute is _wer, so should it be model._wer.use_cer = True instead? I tried both, but the summary table doesn't change to CER. Thanks for your help!

titu1994 commented 2 years ago

The name of the class won't change, nor will the log name. Sadly, switching the name of the log would crash the exp manager unless you were careful to update the metric being monitored.

But rest assured, if your config has cfg.model.use_cer = True, then it is computing CER. You can also set it in code after the model has been built, as shown in the tutorial - however, there is a difference between CTC and RNNT models: CTC has model._wer, RNNT has model.wer.

When you use model.summarize() you can see the modules that exist, and it will show how to access the WERBPE metric (via _wer or wer).

titu1994 commented 2 years ago

We'd recommend using the config method, simply so that after training, the restored model's config tells it to use CER instead of WER mode at inference time. If you make the change only in code, it will go back to using WER after restoration and give different results.
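(Roughly, the two routes look like this; the YAML filename is a hypothetical placeholder, and the config route is the one that gets persisted into the saved model:)

from omegaconf import OmegaConf

# Config route (recommended): set the flag before the model is built, so it is saved
# into the model config and survives restoration from the .nemo file.
cfg = OmegaConf.load("conformer_ctc_bpe_korean.yaml")  # hypothetical config path
cfg.model.use_cer = True

# Code route, after the model has been built (attribute name differs by model type,
# as described above):
#   ctc_model._wer.use_cer = True   # CTC / CTC-BPE models
#   rnnt_model.wer.use_cer = True   # RNNT models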

feddybear commented 2 years ago

I was worried that it wasn't computing CER when I saw the first-epoch predictions coming out blank, but after leaving it for several hours it's now producing reasonable outputs, and the loss and "WER" are decreasing. Thank you!

titu1994 commented 2 years ago

Right, the initial "blank" token prediction is common in ASR; the loss encourages the model to first predict blank at every timestep and then replace blanks with actual subwords/chars.
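(To illustrate why an all-blank prediction decodes to an empty transcript, here is a minimal sketch of generic CTC greedy decoding - collapse repeats, then drop blanks - not NeMo's exact decoder code:)

# CTC greedy decoding: merge repeated tokens, then remove blanks.
# A model that emits blank at every timestep therefore decodes to an empty transcript.
def ctc_greedy_decode(token_ids, blank_id):
    out, prev = [], None
    for t in token_ids:
        if t != prev and t != blank_id:
            out.append(t)
        prev = t
    return out

print(ctc_greedy_decode([0, 0, 0, 0], blank_id=0))        # [] - early in training
print(ctc_greedy_decode([0, 5, 5, 0, 7, 0], blank_id=0))  # [5, 7] - once the model learns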

wonwooo commented 1 year ago

@hslee4716 Did you solve the problem? I'm training on a Korean dataset with the same Conformer-CTC model, and my training loss and WER look exactly like yours. The loss doesn't decrease and the WER stays near 100%.

hslee4716 commented 1 year ago

@wonwooo It's been a while so I don't remember much, but as @titu1994 suggested at the time, I changed WER to CER, followed the tutorial faithfully, and trained with an appropriate learning rate and a large enough batch size. I was testing on a temporary server back then, so I can't check the details now, sorry.