k2-fsa / icefall

https://k2-fsa.github.io/icefall/
Apache License 2.0
931 stars, 295 forks

ctc_loss=0, att_loss=nan, loss=nan in ./conformer_ctc/train.py #580

Closed: thaingoc01 closed this issue 7 months ago

thaingoc01 commented 2 years ago

Attachments: log_prepare_stage14_maxDur500_1GPU_01.txt, train.py.txt

I modified the GigaSpeech recipe to train on my 1200 hours of data (including short utterances and long utterances with segmentations). I also encountered the "invalid input size" error described in https://github.com/k2-fsa/icefall/issues/320, and I applied AmirHussein96's solution to bypass the bad batches. But after about 11,000 batches, I got the ctc_loss=0, att_loss=nan, loss=nan problem. Training command:

./conformer_ctc/train.py --max-duration 500 --num-workers 1 --world-size 1 --exp-dir conformer_ctc/exp_500 --lang-dir data/lang_bpe_500

I attach ./conformer_ctc/train.py and the log file for reference. Thank you very much for any help or suggestions!

python -m k2.version

k2 version: 1.17
Build type: Release
Git SHA1: 3dc222f981b9fdbc8061b3782c3b385514a2d444
Git date: Mon Jul 4 02:13:04 2022
Cuda used to build k2: 11.6
cuDNN used to build k2: 8.2.0
Python version used to build k2: 3.9
OS used to build k2: Ubuntu 18.04.6 LTS
CMake version: 3.18.4
GCC version: 7.5.0
CMAKE_CUDA_FLAGS:  -Wno-deprecated-gpu-targets   -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w  --expt-extended-lambda -gencode arch=compute_35,code=sm_35  -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w  --expt-extended-lambda -gencode arch=compute_50,code=sm_50  -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w  --expt-extended-lambda -gencode arch=compute_60,code=sm_60  -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w  --expt-extended-lambda -gencode arch=compute_61,code=sm_61  -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w  --expt-extended-lambda -gencode arch=compute_70,code=sm_70  -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w  --expt-extended-lambda -gencode arch=compute_75,code=sm_75  -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w  --expt-extended-lambda -gencode arch=compute_80,code=sm_80  -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w  --expt-extended-lambda -gencode arch=compute_86,code=sm_86 -DONNX_NAMESPACE=onnx_c2 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_86,code=compute_86 -Xcudafe --diag_suppress=cc_clobber_ignored,--diag_suppress=integer_sign_change,--diag_suppress=useless_using_declaration,--diag_suppress=set_but_not_used,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=implicit_return_from_non_void_function,--diag_suppress=unsigned_compare_with_zero,--diag_suppress=declared_but_not_referenced,--diag_suppress=bad_friend_decl --expt-relaxed-constexpr --expt-extended-lambda 
-D_GLIBCXX_USE_CXX11_ABI=0 --compiler-options -Wall  --compiler-options -Wno-strict-overflow  --compiler-options -Wno-unknown-pragmas
CMAKE_CXX_FLAGS:  -D_GLIBCXX_USE_CXX11_ABI=0 -Wno-unused-variable  -Wno-strict-overflow
PyTorch version used to build k2: 1.12.0
PyTorch is using Cuda: 11.6
NVTX enabled: True
With CUDA: True
Disable debug: True
Sync kernels : False
Disable checks: False
Max cpu memory allocate: 214748364800
k2 abort: False
csukuangfj commented 2 years ago

But after about 11,000 batches, I got ctc_loss=0, att_loss=nan, loss=nan problem.

Could you print out the batch that causes ctc_loss=0 and attn_loss=nan?

danpovey commented 2 years ago

From the attached log (see above), it seems it was training normally until a certain point, and from then on it always showed ctc_loss=0, att_loss=nan. I suspect the model is always outputting inf, which gets filtered out in the CTC loss computation but not in the other loss computations. This seems to have happened suddenly.

2022-09-19 23:40:57,640 INFO [train.py:556] Epoch 0, batch 11300, loss[ctc_loss=0.3759, att_loss=0.4902, loss=0.4559, over 7457.00 frames. ], tot_loss[ctc_loss=0.405, att_loss=0.4955, loss=0.4683, over 1494451.87 frames. ], batch size: 28
2022-09-19 23:41:47,152 INFO [train.py:556] Epoch 0, batch 11350, loss[ctc_loss=0.356, att_loss=0.4576, loss=0.4271, over 6522.00 frames. ], tot_loss[ctc_loss=0.4023, att_loss=0.4919, loss=0.465, over 1484310.24 frames. ], batch size: 24
2022-09-19 23:42:35,838 INFO [train.py:556] Epoch 0, batch 11400, loss[ctc_loss=0, att_loss=nan, loss=nan, over 12124.00 frames. ], tot_loss[ctc_loss=0.3367, att_loss=nan, loss=nan, over 1562841.32 frames. ], batch size: 108
2022-09-19 23:43:23,191 INFO [train.py:556] Epoch 0, batch 11450, loss[ctc_loss=0, att_loss=nan, loss=nan, over 6951.00 frames. ], tot_loss[ctc_loss=0.2653, att_loss=nan, loss=nan, over 1543849.74 frames. ], batch size: 25
2022-09-19 23:44:09,710 INFO [train.py:556] Epoch 0, batch 11500, loss[ctc_loss=0, att_loss=nan, loss=nan, over 6107.00 frames. ], tot_loss[ctc_loss=0.2004, att_loss=nan, loss=nan, over 1590892.37

One possibility would be to change the training code to check whether the sum of the encoder output is NaN/inf (e.g., x = encoder_output.sum().item(); if x - x != 0:) and, if it is, dump the model and the optimizer state as a checkpoint and then exit. That would make it possible to look for infs in the model and in the optimizer state by looping over the keys of the model dict.
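A minimal sketch of that check, assuming it is called right after the encoder forward pass inside the training loop. The helper name `dump_and_exit_if_nonfinite` and the default checkpoint file name are hypothetical, not part of icefall. It relies on the fact that for a finite float `x - x == 0`, while for NaN or ±inf the difference is NaN.

```python
import sys

import torch


def dump_and_exit_if_nonfinite(
    encoder_output: torch.Tensor,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    batch_idx: int,
    path: str = "nonfinite-checkpoint.pt",  # hypothetical file name
) -> None:
    """Save model and optimizer state, then exit, if the encoder output has NaN/inf."""
    x = encoder_output.detach().sum().item()
    # x - x is 0.0 for any finite x, and NaN if x is NaN or +/-inf.
    if x - x != 0:
        torch.save(
            {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "batch_idx": batch_idx,
            },
            path,
        )
        print(f"Non-finite encoder output at batch {batch_idx}; saved {path}", file=sys.stderr)
        sys.exit(1)
```

Summing first keeps the check to a single host sync per batch; a very large but finite output could overflow the sum to inf and trigger a false alarm, which is harmless for debugging.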

thaingoc01 commented 2 years ago

Sure, and thank you very much for your quick reply. I'm quite new to this field. How should I print out the batch? Should I check whether ctc_loss=0 and then call logging.info(f"batch {batch_idx}, batch info{batch}")?
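One way to express the check proposed above, as a sketch: the hypothetical helper `loss_looks_broken` flags a batch when the CTC loss is exactly 0 or the attention or total loss is not finite, and `summarize_batch` (also hypothetical) builds a compact line instead of logging the whole batch. It assumes the Lhotse-style layout used in the recipe, `batch["inputs"]` plus `batch["supervisions"]["text"]`.

```python
import math


def loss_looks_broken(ctc_loss: float, att_loss: float, loss: float) -> bool:
    """True if the losses match the failure pattern in this issue."""
    return ctc_loss == 0.0 or not math.isfinite(att_loss) or not math.isfinite(loss)


def summarize_batch(batch: dict, batch_idx: int) -> str:
    """One-line description of a batch: feature shape and transcripts."""
    texts = batch["supervisions"]["text"]
    shape = tuple(batch["inputs"].shape)
    return f"batch {batch_idx}: inputs shape={shape}, num_texts={len(texts)}, texts={texts}"
```

In the training loop this could be used as `if loss_looks_broken(ctc_loss.item(), att_loss.item(), loss.item()): logging.info(summarize_batch(batch, batch_idx))`.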

thaingoc01 commented 2 years ago

Thank you! I also suspect the problem is infs caused by very short utterances. (But I checked the durations of the utterances and segments: none is shorter than 0.2 seconds, and each contains text such as "YES" or "UH".) I'm not sure whether I should filter out these short utterances or segments, and if so, below which duration. Maybe less than 0.5 seconds?

I tried training on a 140-hour subset of the data (all short utterances, 2 to 5 seconds long). That training completed and gave a reasonable WER.

csukuangfj commented 2 years ago

How should I print out the batch?

Please have a look at https://github.com/k2-fsa/icefall/blob/9ae2f3a3c5a3c2336ca236c984843c0e133ee307/egs/librispeech/ASR/pruned_transducer_stateless2/train.py#L785

thaingoc01 commented 2 years ago

Hi csukuangfj, I managed to print out the first batch that has ctc_loss=0; the output is in the attached file for your reference. The batch seems quite big, I guess because I set --max-duration 500 in the training command ./conformer_ctc/train.py. But if I reduce --max-duration to the default of 120, ctc_loss=0 still happens at later batches, and there are many more batches, many of which are skipped. Thank you! [Uploading log_prepare_stage14_maxDur500_1GPU_02_debug_extract_batch.txt…]

csukuangfj commented 2 years ago
[Screenshot, 2022-09-24 2:32:01 PM]

Sorry, I am not able to download it.

thaingoc01 commented 2 years ago

I uploaded it to Google Drive: https://drive.google.com/file/d/1W4vCzn75Jjnh9hbeujWmkPvSmM6X-K_v/view?usp=sharing
Looking at the batch, I still have no idea what is wrong.

csukuangfj commented 2 years ago

Could you use https://github.com/k2-fsa/icefall/blob/9ae2f3a3c5a3c2336ca236c984843c0e133ee307/egs/librispeech/ASR/pruned_transducer_stateless2/train.py#L785

to save it to a .pt file?
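A minimal sketch of saving a batch to a `.pt` file so it can be inspected or replayed later, assuming the batch is a dict of tensors, lists, and nested dicts as produced by the Lhotse dataloader. The function name and the `batch-<idx>.pt` naming scheme are hypothetical; the linked icefall code may differ in detail.

```python
import os

import torch


def save_batch(batch: dict, exp_dir: str, batch_idx: int) -> str:
    """Write the batch to <exp_dir>/batch-<batch_idx>.pt and return the path."""
    os.makedirs(exp_dir, exist_ok=True)
    path = os.path.join(exp_dir, f"batch-{batch_idx}.pt")
    torch.save(batch, path)
    return path
```

The file can be read back with `torch.load(path)`. On recent PyTorch versions, where `torch.load` defaults to `weights_only=True`, passing `weights_only=False` may be needed if the batch also contains non-tensor objects such as Lhotse cuts.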

danpovey commented 2 years ago

Guys, I think this is not about the batch; I think this is an instability issue where NaNs or infs appear in the model. These are filtered out by the CTC loss but not by the other losses. I think the focus needs to be on the model parameters, not the batch.

danpovey commented 2 years ago

I notice this code is printing out warnings:

        batch_size = len(batch["supervisions"]["text"])
        # Ref: https://github.com/k2-fsa/icefall/issues/320
        # Skip batches where the number of feature rows differs from the number of transcripts.
        if batch["inputs"].shape[0] != batch_size:
            logging.info(f"In train_one_epoch, skipping batch, batch_idx = {batch_idx}")
            continue

This is odd. Do you have CutConcatenate enabled in asr_datamodule.py? (It might not be related to the error, though.)
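One possible source of the `inputs.shape[0] != len(texts)` mismatch (an assumption, not confirmed in this thread) is cuts that carry more than one supervision, as happens with concatenated cuts or with long recordings that contain several segments. The supervisions dict produced by Lhotse's `K2SpeechRecognitionDataset` has a `sequence_idx` entry mapping each supervision to a row of `batch["inputs"]`; the hypothetical helper below counts supervisions per row to check for this.

```python
from collections import Counter
from typing import Dict, Iterable


def rows_with_multiple_supervisions(sequence_idx: Iterable[int]) -> Dict[int, int]:
    """Map input-row index -> supervision count, keeping only rows with more than one."""
    counts = Counter(int(i) for i in sequence_idx)
    return {row: n for row, n in counts.items() if n > 1}
```

Calling it as `rows_with_multiple_supervisions(batch["supervisions"]["sequence_idx"])` on a skipped batch would show whether that is what triggers the skip; an empty dict rules this cause out.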

thaingoc01 commented 2 years ago

CutConcatenate

Hi Dan, I checked the GigaSpeech recipe: --concatenate-cuts defaults to false, and I don't see it set to true anywhere else, so I guess CutConcatenate is not used here. I added the if/else part following the suggestion in issue #320 (https://github.com/k2-fsa/icefall/issues/320), because I got the same "Invalid input size" error at the beginning.

thaingoc01 commented 2 years ago

Thanks, I will update you when it's ready.

kobenaxie commented 2 years ago

Training a conformer with CTC and pruned RNN-T can produce NaN or inf on our internal dataset, which contains many samples shorter than 1 second. With zero_infinity=True set in torch.nn.CTCLoss, training looks normal so far; I'm still waiting for the final results.
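To illustrate what `zero_infinity` changes, here is a small sketch with random log-probabilities (the sizes are arbitrary): when an input has fewer frames than the target has labels, no CTC alignment exists and `torch.nn.CTCLoss` returns inf, while `zero_infinity=True` replaces that loss (and its gradient) with zero.

```python
import torch
import torch.nn as nn


def ctc_loss_for_short_input(zero_infinity: bool) -> float:
    """CTC loss when the input (2 frames) is too short for the target (3 labels)."""
    torch.manual_seed(0)
    T, N, C = 2, 1, 5  # frames, batch size, classes (index 0 is blank)
    log_probs = torch.randn(T, N, C).log_softmax(dim=-1)
    targets = torch.tensor([[1, 2, 3]])
    input_lengths = torch.tensor([T])
    target_lengths = torch.tensor([3])
    ctc = nn.CTCLoss(blank=0, reduction="sum", zero_infinity=zero_infinity)
    return ctc(log_probs, targets, input_lengths, target_lengths).item()


print(ctc_loss_for_short_input(False))  # inf: no valid alignment exists
print(ctc_loss_for_short_input(True))   # 0.0: the infinite loss is zeroed
```

Note that this hides infeasible utterances rather than removing them: they simply stop contributing to training, which is why filtering them from the data is a complementary fix.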

thaingoc01 commented 2 years ago

Hi Fangjun, I uploaded the first batch with ctc_loss=0 to https://drive.google.com/file/d/1kUJc12Jsi8OjVjeIOLINlArj9jjjKiuS/view?usp=sharing

./conformer_ctc/train.py doesn't have sp, so I set it to None. I also saved the epoch-0 checkpoint, but it will take some time to upload, maybe until tomorrow. Thanks!

thaingoc01 commented 2 years ago

Hi Dan,

I saved the model at the batch where ctc_loss became 0 and uploaded it to https://drive.google.com/file/d/1dqfxjWzkR-7j1H0gFuF946AbEHIrT7oC/view?usp=sharing in case you want to have a look; it is a big file, about 1.2 GB. I don't know yet how to inspect the model for more information. Thank you!

thaingoc01 commented 2 years ago

Hi guys, I will redo the data preparation. I will remove utterances/segments shorter than 1 second (some of which still contain good speech) and repeat the training to see whether the error goes away. I also noticed that some audio files are about 1 second long but empty; I guess this is a recording error, and I'm not sure whether it causes the problem.
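A sketch of the filtering step described above, assuming the segments are available as (id, duration-in-seconds) pairs; the 1-second threshold is the one proposed in this comment, and the function name is hypothetical. Keeping the list of dropped ids makes it easy to see how much conversational data the filter throws away.

```python
from typing import Iterable, List, Tuple


def split_by_min_duration(
    segments: Iterable[Tuple[str, float]], min_duration: float = 1.0
) -> Tuple[List[str], List[str]]:
    """Split (segment_id, duration_in_seconds) pairs into kept and dropped ids."""
    kept: List[str] = []
    dropped: List[str] = []
    for seg_id, duration in segments:
        (kept if duration >= min_duration else dropped).append(seg_id)
    return kept, dropped
```

With Lhotse cuts, the equivalent filter would be something like `cuts.filter(lambda c: c.duration >= 1.0)`.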

danpovey commented 2 years ago

You can load it as a dict with torch.load() and look at the keys and values; e.g., x['model'].keys() will give a list of the parameter names.
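Building on that, a minimal sketch that lists which entries of a state dict contain NaN or inf. The helper name is hypothetical; it assumes the checkpoint stores the model weights under the `"model"` key, as described above.

```python
from typing import Dict, List

import torch


def find_nonfinite_params(state_dict: Dict[str, torch.Tensor]) -> List[str]:
    """Return the names of floating-point tensors that contain NaN or inf."""
    bad = []
    for name, value in state_dict.items():
        if torch.is_tensor(value) and value.is_floating_point():
            if not torch.isfinite(value).all():
                bad.append(name)
    return bad
```

Usage, with a hypothetical checkpoint path: `x = torch.load("conformer_ctc/exp_500/epoch-0.pt", map_location="cpu")`, then `print(find_nonfinite_params(x["model"]))`. On recent PyTorch, loading a full training checkpoint may need `weights_only=False`.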

danpovey commented 2 years ago

It's not easy for me to download the file from here. Please look for infs/NaNs yourself.

danpovey commented 2 years ago

All of the parameters are NaN. Once you get one NaN, it tends to propagate. One way to debug this would be to run the training in pdb and change the code to throw an exception if any part of the loss is NaN. Then, from the debugger command line, check what is inf/NaN and which parts of the model are inf/NaN; that might help narrow it down. To make debugging easier, if it fails consistently at a certain point, the code could be changed to write a checkpoint at a fixed batch number just before then, which would make the problem easier to reproduce. It's been a while since we looked at this recipe, as we have mostly been focusing on RNN-T systems ("pruned_transducer*").
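A sketch of the two code changes suggested above; both helper names and the 50-batch margin are hypothetical. `assert_losses_finite` raises as soon as any loss component is NaN/inf, and `is_debug_checkpoint_batch` picks a fixed batch shortly before a known failure point at which to save a checkpoint.

```python
import math
from typing import Dict


def assert_losses_finite(losses: Dict[str, float], batch_idx: int) -> None:
    """Raise as soon as any loss component is NaN/inf so a debugger can stop here."""
    bad = {name: value for name, value in losses.items() if not math.isfinite(value)}
    if bad:
        raise FloatingPointError(f"Non-finite loss at batch {batch_idx}: {bad}")


def is_debug_checkpoint_batch(batch_idx: int, fail_batch: int, margin: int = 50) -> bool:
    """True for the batch at which to save a checkpoint just before a known failure."""
    return batch_idx == max(fail_batch - margin, 0)
```

With --world-size 1, running `python -m pdb -c continue ./conformer_ctc/train.py ...` makes pdb enter post-mortem debugging when the exception is raised, so the model parameters and activations can be inspected at that point.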

thaingoc01 commented 2 years ago

Hi Daniel, thank you very much for your support! When I encountered the first NaN, I exited and saved the model right away, but the NaNs seem to be in it already. I have filtered out the segments shorter than 1 second, and training now goes well (I still have to skip bad batches as above). I got quite a good WER on read speech (utterances usually about 2 to 7 seconds long). However, because we filtered out about 36,000 short segments that are still meaningful and belong to the conversational dataset, the WER on the conversational test set is worse than that of the traditional Kaldi model.

I'm training another model with pruned_transducer_stateless5, but I saw that its code also filters out segments shorter than 1 second. Have a good weekend!

csukuangfj commented 2 years ago

Could you try #604?

Pruned RNN-T shares the same constraint as CTC training, so you can copy the filtering code from #604.
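The constraint in question is that an utterance must still have enough frames after the encoder's subsampling to emit all of its tokens. A minimal sketch of such a check for CTC, assuming the conformer's Conv2d subsampling maps T input frames to ((T - 1) // 2 - 1) // 2 output frames (please verify against the subsampling module of the recipe you use). The function names are hypothetical, and this is not the code from #604. CTC additionally needs a blank between identical adjacent tokens, so each repeat costs one extra frame.

```python
from typing import Sequence


def frames_after_subsampling(num_frames: int) -> int:
    """Output length of two kernel-3, stride-2 Conv2d layers (assumed subsampling)."""
    return ((num_frames - 1) // 2 - 1) // 2


def ctc_feasible(num_frames: int, token_ids: Sequence[int]) -> bool:
    """True if a CTC alignment exists: T' >= U + number of adjacent repeated tokens."""
    t = frames_after_subsampling(num_frames)
    repeats = sum(1 for a, b in zip(token_ids, token_ids[1:]) if a == b)
    return t >= len(token_ids) + repeats
```

With 10 ms frames, a 1-second cut (100 frames) keeps 24 frames after subsampling. With Lhotse cuts and a SentencePiece model, a filter could look like `cuts.filter(lambda c: ctc_feasible(c.num_frames, sp.encode(c.supervisions[0].text)))`, assuming one supervision per cut.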