k2-fsa / icefall

https://k2-fsa.github.io/icefall/
Apache License 2.0

MGB2 #396

Closed AmirHussein96 closed 1 year ago

AmirHussein96 commented 1 year ago

This is a pull request for the MGB2 recipe. Kindly note that the model is still training and is currently at epoch 3; see the training curves here: https://tensorboard.dev/experiment/zy6FnumCQlmiO7BPsdCmEg/#scalars. One issue is that, with the current setup, one epoch on 2 V100 32GB GPUs with --max-duration 100 takes 2 days, which is very long compared to a similar architecture in ESPnet (half a day per epoch). Any ideas what could cause this? I tried increasing --max-duration to 200, but it gave me an OOM error.

On the other hand, the test WER of 23.53 looks reasonable given that this is still the 3rd epoch. I expect to get something close to ESPnet (Transformer 14.2, Conformer 13.7).

csukuangfj commented 1 year ago

Thanks!

How did you choose the following thresholds?

https://github.com/k2-fsa/icefall/blob/68aa924eeb08a415fac9061df4482b9c491b76c0/egs/mgb2/ASR/conformer_ctc/train.py#L640-L649

Could you update ./local/display_manifest_statistics.py?

Also, could you please try our pruned RNN-T recipe? It not only has a lower WER on LibriSpeech/GigaSpeech but also decodes faster and consumes much less memory.

I would recommend using https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless5 as a starting point.

csukuangfj commented 1 year ago

Also, I think you are converting Kaldi manifests to Lhotse format. Please have a look at https://github.com/k2-fsa/icefall/discussions/391#discussioncomment-2857089

If you used a version of lhotse before https://github.com/lhotse-speech/lhotse/pull/729 to extract the features, I would suggest re-extracting them with the latest lhotse, which uses lilcom_chunky instead of lilcom_hdf5.

AmirHussein96 commented 1 year ago

> lilcom_hdf5

Yes, my current version uses lilcom_hdf5. I will rerun it using lilcom_chunky and let you know. Thank you.

AmirHussein96 commented 1 year ago

The min=0.5 s and max=30 s duration boundaries are similar to what I used with ESPnet, based on my experience. Segments longer than 30 s cause memory issues and model underfitting (the model needs many epochs to start fitting the training data). I will check and update ./local/display_manifest_statistics.py.
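
For readers following the thresholds discussion, here is a minimal sketch of such a duration filter. It assumes cut-like objects expose a `duration` attribute in seconds (as lhotse cuts do); `FakeCut` and `keep_by_duration` are hypothetical names used only for illustration, and the actual filter in the MGB2 train.py may differ.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class FakeCut:
    """Hypothetical stand-in for a lhotse Cut; only `duration` (seconds) is used."""
    id: str
    duration: float


def keep_by_duration(c: FakeCut, min_sec: float = 0.5, max_sec: float = 30.0) -> bool:
    # Drop segments too short to be useful or too long to fit in GPU memory
    return min_sec <= c.duration <= max_sec


cuts: List[FakeCut] = [
    FakeCut("a", 0.3),
    FakeCut("b", 7.57),
    FakeCut("c", 29.66),
    FakeCut("d", 31.0),
]
kept = [c.id for c in cuts if keep_by_duration(c)]
```

With a real CutSet, the same kind of predicate is passed to `train_cuts.filter(...)`, which is the pattern used later in this thread for the text-length filter.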

Regarding RNN-T, I was actually considering it as my next step, so yes, I will run it as well. Thank you for pointing me to the latest and best RNN-T configuration.

AmirHussein96 commented 1 year ago

So to use lilcom_chunky, I should change storage_type from LilcomHdf5Writer to LilcomChunkyWriter in compute_fbank_mgb2.py, right? Should I do the same with MUSAN?

```python
cut_set = cut_set.compute_and_store_features(
    extractor=extractor,
    storage_path=f"{output_dir}/feats_{partition}",
    # when an executor is specified, make more partitions
    num_jobs=num_jobs if ex is None else 80,
    executor=ex,
    storage_type=LilcomChunkyWriter,
)
```
csukuangfj commented 1 year ago

> So to use lilcom_chunky I should change storage_type from LilcomHdf5Writer to LilcomChunkyWriter in compute_fbank_mgb2.py, right?

Yes. Please see https://github.com/k2-fsa/icefall/blob/1235e23fbfbc90d35311158aa2a9121c8278f001/egs/librispeech/ASR/local/compute_fbank_librispeech.py#L94

> Should I do the same with musan?

Yes, you can do that. Please see https://github.com/k2-fsa/icefall/blob/1235e23fbfbc90d35311158aa2a9121c8278f001/egs/librispeech/ASR/local/compute_fbank_musan.py#L95

Note that the filenames end with jsonl.gz, not json.gz, and you HAVE TO use to_file(), not to_json().

Also, please replace load_manifest with load_manifest_lazy, see https://github.com/k2-fsa/icefall/blob/1235e23fbfbc90d35311158aa2a9121c8278f001/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py#L227-L229

https://github.com/k2-fsa/icefall/blob/1235e23fbfbc90d35311158aa2a9121c8278f001/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py#L424-L426
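
The .jsonl.gz requirement is what makes lazy loading possible: every line is a self-contained JSON object, so a reader can stream records one at a time instead of parsing the whole manifest up front. The sketch below illustrates the idea with the standard library only; it is not lhotse's implementation, and the record fields are hypothetical.

```python
import gzip
import json
import os
import tempfile
from typing import Iterable, Iterator


def write_jsonl_gz(path: str, records: Iterable[dict]) -> None:
    """Write one JSON object per line, gzip-compressed (.jsonl.gz layout)."""
    with gzip.open(path, "wt", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")


def iter_jsonl_gz(path: str) -> Iterator[dict]:
    """Lazily yield one record at a time; only the current line is parsed."""
    with gzip.open(path, "rt", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                yield json.loads(line)


path = os.path.join(tempfile.mkdtemp(), "cuts_train.jsonl.gz")
write_jsonl_gz(path, [{"id": "utt1", "duration": 7.57}, {"id": "utt2", "duration": 2.5}])
first = next(iter_jsonl_gz(path))  # reads just the first line of the file
```

A single JSON array (.json.gz) cannot be streamed this way with the standard json module, because the whole document has to be parsed as one value.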

And replace BucketingSampler with DynamicBucketingSampler; see https://github.com/k2-fsa/icefall/blob/1235e23fbfbc90d35311158aa2a9121c8278f001/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py#L308-L316
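
To make the role of --max-duration and bucketing concrete, here is a simplified, hypothetical sketch of duration-based bucketing: utterances are sorted by length, split into equal-count buckets, and each bucket is cut into batches whose total duration stays under max_duration. This is not lhotse's DynamicBucketingSampler algorithm (the real sampler is designed to work with lazily loaded manifests); it only illustrates why similar-length utterances end up in the same batch, which reduces padding.

```python
from typing import List


def bucket_and_batch(
    durations: List[float], num_buckets: int, max_duration: float
) -> List[List[int]]:
    """Group utterance indices into batches of similar-length utterances.

    Simplified illustration only: sort by duration, split into `num_buckets`
    equal-count buckets, then greedily fill batches inside each bucket until
    adding the next utterance would exceed `max_duration` seconds.
    """
    order = sorted(range(len(durations)), key=lambda i: durations[i])
    bucket_size = max(1, -(-len(order) // num_buckets))  # ceiling division
    batches: List[List[int]] = []
    for start in range(0, len(order), bucket_size):
        batch: List[int] = []
        total = 0.0
        for idx in order[start:start + bucket_size]:
            if batch and total + durations[idx] > max_duration:
                batches.append(batch)
                batch, total = [], 0.0
            batch.append(idx)
            total += durations[idx]
        if batch:
            batches.append(batch)
    return batches


batches = bucket_and_batch([1.0, 29.0, 2.0, 28.0, 3.0, 27.0], num_buckets=2, max_duration=30.0)
```

An utterance longer than max_duration still forms its own batch here; recipes avoid that case by filtering long cuts beforehand, as with the 30 s threshold discussed above.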

AmirHussein96 commented 1 year ago

Also, I think I closed the PR by mistake. @csukuangfj, can you reopen it?

AmirHussein96 commented 1 year ago

After making the suggested modifications (storing features with storage_type=LilcomChunkyWriter and loading them with load_manifest_lazy), I also managed to increase --max-duration from 150 to 300, but 6k iterations now take twice as long as before. The training curve for the new setup is here: https://tensorboard.dev/experiment/N8X0P5pHQyiWwvdT7RTp8w/. Note that I am still using 2 V100 32GB GPUs. Should I try --storage-type numpy_hdf5?

csukuangfj commented 1 year ago

> I managed to increase the --max-duration from 150 to 300, the 6k iteration takes twice what it took previously

Since you are doubling the max duration, the time for 6k iterations should also increase. But I am not sure whether it is normal that the time is doubled. @pzelasko Do you have any comments?

> Should I try --storage-type numpy_hdf5 ?

I don't know whether switching to numpy_hdf5 will help you decrease the training time. I think you were using chunked_lilcom_hdf5, i.e., ChunkedLilcomHdf5Writer.

danpovey commented 1 year ago

The time per minibatch grows roughly linearly when you increase --max-duration, so I think that is as expected.

pzelasko commented 1 year ago

LilcomChunky and LilcomHdf5 should have very similar performance; I don't expect you'd gain anything here.

As Dan says, scaling up the batch size by 2x can explain why it takes almost twice as long to run (unless you had a small model that underutilizes the GPU, which is likely not the case here).
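
The linear-scaling argument can be made concrete with a little arithmetic. Assuming the time per minibatch is proportional to the seconds of audio in it (the constant and corpus size below are hypothetical, not measured values), doubling --max-duration doubles the wall-clock time for a fixed number of iterations but halves the number of iterations per epoch, so the time per epoch stays the same.

```python
def time_per_batch(max_duration: float, c: float) -> float:
    # Assumption: cost is linear in the seconds of audio in a (full) batch
    return c * max_duration


def time_for_iterations(n_iters: int, max_duration: float, c: float) -> float:
    return n_iters * time_per_batch(max_duration, c)


def time_per_epoch(total_audio_sec: float, max_duration: float, c: float) -> float:
    num_batches = total_audio_sec / max_duration  # assumes batches are nearly full
    return num_batches * time_per_batch(max_duration, c)


C = 0.01  # hypothetical: 0.01 s of compute per second of audio
TOTAL = 1000 * 3600.0  # hypothetical corpus: 1000 hours, in seconds

ratio_6k = time_for_iterations(6000, 300, C) / time_for_iterations(6000, 150, C)
epoch_150 = time_per_epoch(TOTAL, 150, C)
epoch_300 = time_per_epoch(TOTAL, 300, C)
```

Real runs deviate from this because of padding, data-loading stalls, and GPU under-utilization with small batches, which is the exception pzelasko mentions.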

AmirHussein96 commented 1 year ago

Related to the slow-training discussion, @danpovey suggested either:

So first I double-checked the node's drives with lsblk -d -o name,rota, and they look like SSDs:

```
NAME    ROTA
nvme0n1    0
nvme1n1    0
```
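
For reference, the ROTA column printed by lsblk is 1 for rotational disks (HDDs) and 0 for non-rotational ones (SSD/NVMe). A tiny, hypothetical helper that parses output in the format shown above:

```python
from typing import Dict


def parse_lsblk_rota(output: str) -> Dict[str, bool]:
    """Map device name -> True if rotational (HDD), from `lsblk -d -o name,rota`."""
    rows = [line.split() for line in output.strip().splitlines()]
    header, devices = rows[0], rows[1:]
    name_col, rota_col = header.index("NAME"), header.index("ROTA")
    return {dev[name_col]: dev[rota_col] == "1" for dev in devices}


sample = """NAME    ROTA
nvme0n1    0
nvme1n1    0
"""
rotational = parse_lsblk_rota(sample)
```

On Linux, the same flag can also be read directly from /sys/block/<device>/queue/rotational.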

So, as a first attempt, I increased the number of workers from 2 to 8, keeping --max-duration at 150 as in https://tensorboard.dev/experiment/zy6FnumCQlmiO7BPsdCmEg/#scalars, because 300 and 200 gave OOM. The iterations indeed became 4 times faster; you can find the new setup with 8 workers here: https://tensorboard.dev/experiment/WvSg4yn8SYyJlKyQGkls0A/#scalars.

danpovey commented 1 year ago

Increasing num-workers increases RAM utilization but does not increase GPU memory utilization so it should not affect the maximum --max-duration you can use.

My feeling is that the issue is that he is running from an HDD, not an SSD, so disk-access latency is high. I think a solution would be either to use WebDataset for sequential access, or just to use a much larger number of workers. If he uses jsonl in the recipe, that should prevent the large num-workers from causing excessive memory use. (Note: a recent PR in icefall from @csukuangfj, possibly just merged, makes some changes to use jsonl instead of json.)

csukuangfj commented 1 year ago

@AmirHussein96 I would suggest using the following two files as a reference:

csukuangfj commented 1 year ago

If your dataset is quite large, you can use the following two files as a reference, which split the dataset into smaller pieces:

AmirHussein96 commented 1 year ago

> @AmirHussein96 I would suggest you using the following two files as a reference:

Yes, I followed these changes and increased the number of workers from 2 to 8 per GPU (I am using 2 GPUs). The utilization is shown below; it is much better now, 10-12 h per epoch compared to 2 days previously: (screenshot: Capture)

csukuangfj commented 1 year ago

@AmirHussein96 Please use load_manifest instead of load_manifest_lazy if the manifest can be read into memory all at once. It can speed things up a lot.

AmirHussein96 commented 1 year ago

@csukuangfj

Recent updates:

The conformer_ctc training for 45 epochs has finished, the tensorboard is here: https://tensorboard.dev/experiment/QYNzOi52RwOX8yvtpl3hMw/#scalars

I tried the following decoding methods (note: I had to reduce max_active_states from 10000 to 5000 to fit on a P100 16GB GPU):

  1. whole-lattice-rescoring, results => dev: 15.62, test: 15.01

     ```bash
     ./conformer_ctc/decode.py --epoch 45 --avg 5 \
       --exp-dir conformer_ctc/exp_5000_att0.8 --lang-dir data/lang_bpe_5000 \
       --method whole-lattice-rescoring --nbest-scale 0.5 --lm-dir data/lm \
       --max-duration 30 --num-paths 1000 --num-workers 20
     ```

  2. Attention-decoder, results => dev: 15.89, test: 15.08

     ```bash
     ./conformer_ctc/decode.py --epoch 45 --avg 5 --max-duration 30 \
       --num-paths 1000 --exp-dir conformer_ctc/exp_5000_att0.8 \
       --lang-dir data/lang_bpe_5000 --method attention-decoder \
       --shuffle False --enable-musan False --enable-spec-aug False \
       --nbest-scale 0.5 --num-workers 20
     ```

It looks like there is still a considerable gap compared to a similar ESPnet setup, decoded with beam search (beam size 20) and no LM: Transformer => dev: 14.6, test: 14.2; Conformer => test: 13.7. https://github.com/espnet/espnet/blob/master/egs/mgb2/asr1/RESULTS.md

AmirHussein96 commented 1 year ago

@csukuangfj

I tried RNN-T on MGB2 with the following command:

```bash
./pruned_transducer_stateless5/train.py \
  --world-size 4 \
  --num-epochs 40 \
  --start-epoch 1 \
  --exp-dir pruned_transducer_stateless5/exp \
  --max-duration 30
```

This is the tensorboard: https://tensorboard.dev/experiment/xuOlsEwGRay3qspf7HezLw/#scalars&_smoothingWeight=0.693. The validation loss looks good, but the training loss is weird. Is this expected?

For some reason, RNN-T asks for a lot of memory that does not fit into a V100 16GB GPU. Any ideas why this is happening? [errors-7078120.txt](https://github.com/k2-fsa/icefall/files/9000144/errors-7078120.txt)

danpovey commented 1 year ago

It looks like, for one of your jobs, an inf got into the pruned_loss at some point. But this may only affect the diagnostics. Which version of the scripts were you using here? I don't see these scripts in your PR. Edit: in the logs I see that it is pruned_transducer_stateless5. You can try @yaozengwei's recent PR where he simplifies the RandomCombine module, removing the linear layers. I have seen experiments diverge in half precision due to the Linear module in the RandomCombine module causing large outputs; it has been removed in that PR. That is possibly the issue, anyway.

AmirHussein96 commented 1 year ago

Hi @danpovey, apologies for the late reply. I have pushed the updated pruned transducer stateless config that I am using with MGB2; please check it and let me know what you think.

The details of the k2 version I am using are below:

```
Build type: Release
Git SHA1: 3c606c27045750bbbb7a289d8b2b09825dea521a
Git date: Mon Jun 27 03:06:58 2022
Cuda used to build k2: 10.2
cuDNN used to build k2: 8.0.5
Python version used to build k2: 3.8
OS used to build k2: Red Hat Enterprise Linux Server release 7.8 (Maipo)
CMake version: 3.18.0
GCC version: 8.4.0
CMAKE_CUDA_FLAGS:   -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w  --expt-extended-lambda -gencode arch=compute_60,code=sm_60 -D_GLIBCXX_USE_CXX11_ABI=0 --compiler-options -Wall  --compiler-options -Wno-strict-overflow  --compiler-options -Wno-unknown-pragmas
CMAKE_CXX_FLAGS:  -D_GLIBCXX_USE_CXX11_ABI=0 -Wno-unused-variable  -Wno-strict-overflow
PyTorch version used to build k2: 1.7.1
PyTorch is using Cuda: 10.2
```

The lhotse version I am using is `1.3.0.dev+git.a07121a.clean`.
The icefall version I am using is:

```
>>> icefall.get_env_info()
{'k2-version': '1.16', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '3c606c27045750bbbb7a289d8b2b09825dea521a', 'k2-git-date': 'Mon Jun 27 03:06:58 2022', 'lhotse-version': '1.3.0.dev+git.a07121a.clean', 'torch-version': '1.7.1', 'torch-cuda-available': False, 'torch-cuda-version': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'test', 'icefall-git-sha1': 'e24e6ac-dirty', 'icefall-git-date': 'Mon Jun 27 01:23:06 2022', 'icefall-path': '/alt-arabic/speech/amir/k2/tmp/icefall', 'k2-path': '/alt-arabic/speech/amir/k2/tmp/k2/k2/python/k2/__init__.py', 'lhotse-path': '/alt-arabic/speech/amir/k2/tmp/lhotse/lhotse/__init__.py', 'hostname': 'cribrighthead001', 'IP address': '10.141.255.254'}
```

AmirHussein96 commented 1 year ago

I recently ran the pruned stateless transducer experiment again on a V100 32GB GPU, and I am still getting OOM. The script tried to allocate 64 GB, which is unreasonable. Any idea how to solve this? Note that I no longer have the inf issue. https://tensorboard.dev/experiment/3Ib5cVt7R5Kq2YW05j56Fg/

log-train-2022-06-29-16-03-10.txt

errors-7091352.txt

csukuangfj commented 1 year ago

Can you update egs/mgb2/ASR/local/display_manifest_statistics.py to match your dataset? I just want to see the duration distribution of your data. It helps to select the threshold for https://github.com/k2-fsa/icefall/blob/b08f4424cc37e81151010cc41a837dc8bc12b0bc/egs/mgb2/ASR/pruned_transducer_stateless5/train.py#L960

AmirHussein96 commented 1 year ago

Done

danpovey commented 1 year ago

OK, it seems to be happening in the autograd backward. I'm not aware of any part of the autograd backward of this stuff that should allocate more memory than the forward, so this looks to me like a bug. Unfortunately, because the backward happens in C++, there is no way to get a Python-level trace that will tell us very much. What you can do, though, is run it in gdb and do catch throw to catch the exception when it happens, then look at the C++ stack trace to see what it tells us. This will work better with a version of k2 that has been compiled with debug info (e.g., compiled locally). You should be able to do something like:

```
gdb --args python3 your_script args ...
(gdb) catch throw
(gdb) r
```

... hopefully no other exceptions will be thrown from the C++ level that would cause false alarms.


-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/local/QCRI/ahussein/anaconda3/envs/k2/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
    fn(i, *args)
  File "/alt-arabic/speech/amir/k2/tmp/icefall/egs/mgb2/ASR1/pruned_transducer_stateless5/train.py", line 1001, in run
    train_one_epoch(
  File "/alt-arabic/speech/amir/k2/tmp/icefall/egs/mgb2/ASR1/pruned_transducer_stateless5/train.py", line 765, in train_one_epoch
    scaler.scale(loss).backward()
  File "/home/local/QCRI/ahussein/anaconda3/envs/k2/lib/python3.8/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/local/QCRI/ahussein/anaconda3/envs/k2/lib/python3.8/site-packages/torch/autograd/__init__.py", line 130, in backward
    Variable._execution_engine.run_backward(
RuntimeError: CUDA out of memory. Tried to allocate 64.07 GiB (GPU 0; 31.75 GiB total capacity; 6.42 GiB already allocated; 19.10 GiB free; 11.32 GiB reserved in total by PyTorch)
AmirHussein96 commented 1 year ago

Apologies for the late reply; all the powerful GPUs were reserved and I just got access to one. I tried to debug the OOM issue with pdb, and it seems that a batch of 4 segments of 24 s each was too large for the current configuration. So I reduced the maximum segment duration to 20 s, the encoder size from 24 to 12, warm-up steps from 5000 to 25000, input size from 384 to 256, and output size from 512 to 256, and it seems to work now. However, the pruned loss now goes up while the other losses go down. Is that normal? https://tensorboard.dev/experiment/dkYPu7vHR2ywfYAXYHxHqA/#scalars

csukuangfj commented 1 year ago

Yes, it's normal. The pruned loss will go down after 3000 batches if you use the same default settings as LibriSpeech.

AmirHussein96 commented 1 year ago

@csukuangfj @danpovey I'm getting OOM again :( I tried to debug the batch that caused the issue, but it looks OK to me. Another thing to note is that I am using 1 V100 32GB GPU, and it took around 3 days to hit this error, which is still in epoch 1. I am using 4 workers; going beyond 4 workers causes OOM from the beginning of training.

The command I am using is:

```bash
./pruned_transducer_stateless5/train.py \
  --world-size 1 \
  --num-epochs 40 \
  --start-epoch 1 \
  --exp-dir pruned_transducer_stateless5/exp \
  --max-duration 100 \
  --bucketing-sampler 1 \
  --num-buckets 50
```

The debugging session is below:

```
    Variable._execution_engine.run_backward(
RuntimeError: CUDA out of memory. Tried to allocate 62.73 GiB (GPU 0; 31.75 GiB total capacity; 2.84 GiB already allocated; 4.19 GiB free; 10.36 GiB reserved in total by PyTorch)
Uncaught exception. Entering post mortem debugging
Running 'cont' or 'step' will restart the program
> /home/local/QCRI/ahussein/anaconda3/envs/k2/lib/python3.8/site-packages/torch/autograd/__init__.py(130)backward()
-> Variable._execution_engine.run_backward(
(Pdb) l
125         grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
126         grad_tensors_ = _make_grads(tensors, grad_tensors_)
127         if retain_graph is None:
128             retain_graph = create_graph
129
130  ->     Variable._execution_engine.run_backward(
131             tensors, grad_tensors_, retain_graph, create_graph,
132             allow_unreachable=True)  # allow_unreachable flag
(Pdb) u
> /home/local/QCRI/ahussein/anaconda3/envs/k2/lib/python3.8/site-packages/torch/tensor.py(221)backward()
-> torch.autograd.backward(self, gradient, retain_graph, create_graph)
(Pdb) l
216                     relevant_args,
217                     self,
218                     gradient=gradient,
219                     retain_graph=retain_graph,
220                     create_graph=create_graph)
221  ->         torch.autograd.backward(self, gradient, retain_graph, create_graph)
(Pdb) u
> /alt-arabic/speech/amir/k2/tmp/icefall/egs/mgb2/ASR1/pruned_transducer_stateless5/train.py(766)train_one_epoch()
-> scaler.scale(loss).backward()
(Pdb) print(loss)
tensor(inf, device='cuda:0', grad_fn=<AddBackward0>)
(Pdb) batch['inputs'].shape
torch.Size([13, 757, 80])
(Pdb) for i in range(0,17):  batch['supervisions']['cut'][i].duration
7.57
7.5444375
7.57
7.57
7.5333125
7.57
7.53
7.5181875
7.4818125
7.4545625
7.57
7.57
7.363625
```

More details are in the attached log file: log-train-2022-08-02-13-36-21.txt

danpovey commented 1 year ago

The pruned loss is inf. @csukuangfj, perhaps we should modify the code that sums up the losses, in icefall/icefall/, to filter out infs and NaNs before incorporating them into the sum? The inf loss is likely caused by a transcript that was too long for the audio, and once it gets into the averaged/total stats it stays there, although it doesn't hurt the training.
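
A minimal sketch of the idea Dan describes: skip non-finite values when accumulating loss statistics. The helper name and the dictionary-of-floats representation are hypothetical; the actual change in icefall is #525 and may look different.

```python
import math
from typing import Dict


def accumulate_finite(totals: Dict[str, float], losses: Dict[str, float]) -> int:
    """Add each finite loss value into `totals`; skip inf/nan. Returns how many were skipped."""
    skipped = 0
    for name, value in losses.items():
        if math.isfinite(value):
            totals[name] = totals.get(name, 0.0) + value
        else:
            # An inf/nan added here would poison the running total for the rest of training
            skipped += 1
    return skipped


totals: Dict[str, float] = {}
n_bad = accumulate_finite(totals, {"simple_loss": 1.5, "pruned_loss": float("inf")})
accumulate_finite(totals, {"simple_loss": 0.5, "pruned_loss": 2.0})
```

This only affects the reported statistics; whether to also skip the backward pass for such a batch is a separate decision.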

The OOM appears to be caused by some kind of bug rather than simply a too-large batch. You are actually using a very small amount of memory, less than 3G:

```
62.73 GiB (GPU 0; 31.75 GiB total capacity; 2.84 GiB already allocated; 4.19 GiB free; 10.36 GiB reserved in total by PyTorch)
```

The bug happens in backprop, so we don't get a good backtrace. To debug it, I think we need to run it inside gdb and do catch throw before r, so that we catch C++ exceptions (and can get a backtrace). Something like:

```
gdb --args python3 [script] [args]
(gdb) catch throw
(gdb) r
```

[I notice that I told you above to do this, but you used pdb instead.]

BTW, --num-workers is not relevant here: that is the number of workers in the data loader, and the error is not related to the data loader.

csukuangfj commented 1 year ago

> The pruned loss is inf. @csukuangfj perhaps we should modify the code that sums up the losses, in icefall/icefall/, to filter out inf's and nan's before incorporating them into the sum?

Please see #525

AmirHussein96 commented 1 year ago

@danpovey, thank you so much for your feedback, and I apologize for not using gdb. I will try again and let you know.

danpovey commented 1 year ago

@AmirHussein96, if checkpoint files were written, you may be able to restart from a checkpoint. [It may even be possible to change the code to load the failing minibatch from disk and run on that, although IDK whether the bug will still happen.] Of course, there is no guarantee that the bug will still happen in the same place if you don't run in exactly the same way; but it might.
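
Dan's suggestion of loading the failing minibatch from disk can be sketched as follows. The helper names are hypothetical, and pickle is used only to keep the example dependency-free; a real icefall batch contains torch tensors, for which torch.save/torch.load would be the natural choice. The trigger here is a non-finite loss, the symptom seen just before the crash in this thread.

```python
import math
import os
import pickle
import tempfile
from typing import Optional


def save_batch_if_bad(batch: dict, loss: float, out_dir: str, batch_idx: int) -> Optional[str]:
    """Pickle `batch` when `loss` is inf/nan so it can be replayed later; return the path."""
    if math.isfinite(loss):
        return None
    path = os.path.join(out_dir, f"bad-batch-{batch_idx}.pkl")
    with open(path, "wb") as f:
        pickle.dump(batch, f)
    return path


def load_batch(path: str) -> dict:
    """Reload a saved batch so a single forward/backward pass can be rerun on it."""
    with open(path, "rb") as f:
        return pickle.load(f)


out_dir = tempfile.mkdtemp()
batch = {"durations": [7.57, 7.53], "texts": ["hypothetical text 1", "hypothetical text 2"]}
path = save_batch_if_bad(batch, float("inf"), out_dir, batch_idx=6000)
restored = load_batch(path)
```

Replaying only the saved batch makes slow tools such as cuda-memcheck or a CPU-only run practical.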

csukuangfj commented 1 year ago

If you are going to resume training from a checkpoint, please use the changes from #421 to update your train.py

AmirHussein96 commented 1 year ago

I captured the error with gdb, but I'm not sure what I should do next.

Capture2 gdb_stack.txt

csukuangfj commented 1 year ago

Could you show the output of

```
(gdb) backtrace
```

AmirHussein96 commented 1 year ago

Yes, I dumped it here: gdb_stack.txt

csukuangfj commented 1 year ago

Could you reduce your --max-duration and re-try again?

csukuangfj commented 1 year ago

> I captured the error with gdb but not sure what should I do next
>
> Capture2 gdb_stack.txt

The current batch size is 47 with 2600 tokens, which is a little large.

AmirHussein96 commented 1 year ago

I reduced --max-duration from 300 to 100, but I still get the same issue. Interestingly, it happens right after one of the elements of the pruned loss becomes inf. (screenshot: Capture2)

csukuangfj commented 1 year ago

Could you load the saved batch and look at the specific utterance that is causing the inf loss? Is there anything special about that utterance?

AmirHussein96 commented 1 year ago

Since I will be very busy from today through next week, I just ran the script skipping that batch to see whether it still causes the issue, and it looks like the issue is resolved. As you suggested, @csukuangfj, I will investigate the batch next week and post my findings here.

csukuangfj commented 1 year ago

Thanks!

danpovey commented 1 year ago

Torch has a remarkable amount of template expansion, causing a ridiculously long stack for something simple. I managed to locate the informative stack frame as frame 32!

```
#32 0x00002aaacf47d504 in at::gather_backward(at::Tensor const&, at::Tensor const&, long, at::Tensor const&, bool) ()
```

... the error is in the backward of torch.gather(). I think it checks that the indexes are correct in the forward pass, so IMO this would likely indicate a bug in Torch. We don't know whether some simple torch operators, like the indexing operator, might call gather behind the scenes, so I don't know how much this tells us about where the actual problem comes from.

I think we need to spend some effort debugging this, though; others get random-seeming errors, and I think we should find out the real cause. If we can reproduce the error from just one problematic batch (is this the case?), that would be great progress. Running that in cuda-memcheck might be one thing we could do (cuda-memcheck is slow, but one batch should be OK); another is running with and without CUDA by changing train.py to use the CPU as the device, which would check whether it's a CUDA-specific problem.

It may be possible to go up the stack ("up" in gdb) until we get a meaningful stack frame like at::gather_backward, and then try to print out local variables; but for that I suspect we need a version of Torch that is compiled with debug information. IDK if that's possible to pip install...

danpovey commented 1 year ago

Incidentally, the code of torch gather_backward is just:

```cpp
Tensor gather_backward(const Tensor& grad, const Tensor& self, int64_t dim, const Tensor& index, bool sparse_grad) {
  if (sparse_grad) {
    return at::_gather_sparse_backward(self, dim, index, grad);
  }
  return at::zeros(self.sizes(), grad.options()).scatter_add_(dim, index, grad);
}
```

... and I just checked that torch 1.7.1 has this same code. It doesn't look like it would have any data-dependent bugs at the point where we got the error (which is when calling at::zeros, before doing the scatter_add_). The only reason I can see why this would error is that self.sizes() is somehow invalid. The most likely reason for that, IMO, is memory corruption of the stack (i.e., on the CPU). I have a hard time seeing how our code would cause memory corruption, since most "interesting" memory accesses would be on the GPU, not the CPU, and the CPU code is quite simple. But who knows. In any case, valgrind/memcheck should be able to see this. Unfortunately, with Python programs valgrind spits out a huge number of false warnings, but we may be able to find the true error somehow.
This would likely be too slow if we have to run the entire training set, but if we are able to reproduce the problem with a "bad batch" that would make it much easier.
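
The gather/scatter_add relationship in the C++ code above can be illustrated in one dimension with plain Python (a sketch, not Torch's implementation): the forward pass reads x[index[i]], and the backward pass allocates a zero buffer shaped like the input self (the at::zeros(self.sizes(), ...) call) and scatter-adds the output gradients into it. The size of that allocation depends only on self.sizes(), not on the gathered output, which is why an invalid or unexpectedly large shape would surface exactly at that call.

```python
from typing import List


def gather_1d(x: List[float], index: List[int]) -> List[float]:
    # Forward: out[i] = x[index[i]]
    return [x[j] for j in index]


def gather_1d_backward(grad_out: List[float], input_len: int, index: List[int]) -> List[float]:
    # Mirrors at::zeros(self.sizes()).scatter_add_(dim, index, grad):
    # the buffer is sized like the *input*, not like the gathered output.
    grad_x = [0.0] * input_len
    for i, j in enumerate(index):
        grad_x[j] += grad_out[i]  # repeated indices accumulate
    return grad_x


x = [10.0, 20.0, 30.0]
index = [2, 0, 2]
out = gather_1d(x, index)
grad_x = gather_1d_backward([1.0, 1.0, 1.0], len(x), index)
```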

AmirHussein96 commented 1 year ago

I was just checking the batch that caused the issue and I found the following:

```
(Pdb) for i in batch['supervisions']['cut']: print(len(i.supervisions[0].text))
3746
309
408
389
198
284
(Pdb) for i in batch['supervisions']['cut']: print(i.supervisions[0].duration)
29.66
29.57775
29.3818125
28.58
28.311125
28.22
```

So the issue is not caused by the inf from the pruned loss; it is caused by the MGB2 data. For some reason, one of the texts is extremely long (3746 characters in a 29.66 s segment), which cannot be detected from the duration alone. I suggest adding filtering by text length.

AmirHussein96 commented 1 year ago

The training so far looks good: https://tensorboard.dev/experiment/YyNv45pfQ0GqWzZ898WOlw/#scalars. However, when I run decoding:

```bash
for method in greedy_search modified_beam_search fast_beam_search; do
  ./pruned_transducer_stateless5/decode.py \
    --epoch 28 \
    --avg 10 \
    --exp-dir ./pruned_transducer_stateless5/exp \
    --max-duration 300 \
    --decoding-method $method \
    --max-sym-per-frame 1 \
    --num-encoder-layers 12 \
    --dim-feedforward 2048 \
    --nhead 8 \
    --encoder-dim 512 \
    --decoder-dim 512 \
    --joiner-dim 512 \
    --use-averaged-model True
done
```

The results contain lots of repetitions, especially of incomplete words, which looks like a streaming simulation. The hypothesis words follow the same order as the reference, but because of the repetitions the WER is horrible (it reaches 600). Is there a way to remove the repetitions and incomplete words?

ref=['منذ', 'متى', 'أصبح', 'بوتين', 'قائد', 'القومية', 'العربية', 'وحامل', 'مشعل', 'الممانعة', 'والمقاومة'] hyp=['منذ', 'متى', 'منذ', 'متى', 'منذ', 'متى', 'منذ', 'متى', 'متى', 'متى', 'متىى', 'متى', 'أصبح', 'أصبح', 'وأ', 'أصبح', 'أصبح', 'وأصبح', 'بوتين', 'بوتين', 'بوتين', 'بوتين', 'بوتين', 'بوتين', 'بوتين', 'الوطنية', 'قائد', 'لك', 'قائد', 'قائد', 'قائد', 'قائد', 'قائد', 'قائد', 'قائد', 'قائد', 'قائد', 'قائد', 'قائد', 'قائد', 'قائد', 'قائد', 'قائد', 'القومي', 'القومي', 'القومي', 'القومي', 'القومي', 'القومي', 'القومي', 'القومي', 'القومي', 'القومية', 'القومية', 'العربية', 'القومية', 'العربية', 'العربية', 'العربية', 'العربية', 'العربية', 'العربية', 'العربية', 'العربية', 'العربية', 'العربي', 'وحا', 'وحا', 'وحا', 'وحا', 'وحاملاململ', 'مشعل', 'مش', 'مشعل', 'مشعل', 'مشعلعل', 'مشعلعلعل', 'الممان', 'الممانمان', 'الممانمان', 'الممانعة', 'الممانعة', 'والةمانعة', 'والمقاومة', 'والمقاومة', 'والمقاومة', 'والمقاومة', 'والمقاومة', 'والمقاومة', 'والمقاومة', 'والمقاومة', 'والمقاومة', 'والمقاومة', 'والمقاومة', 'ما', 'قالوم', 'لهمانوميات', 'قال', 'لهومانويني']

ref=['بسبب', 'استدعائه', 'للروس', 'ألم', 'تستنجد', 'بفصائل', 'من', 'مشرق', 'الأرض', 'ومغربها'] hyp=['بسبب', 'بسبب', 'بسبب', 'بسبب', 'بسبب', 'بسبب', 'بسبب', 'بسبب', 'بسبب', 'استد', 'استد', 'استدع', 'استدع', 'استدعائهائهائهائهائهائهائه', 'للروس', 'للروس', 'للروس', 'للروس', 'للروس', 'للروس', 'للروس', 'للروس', 'لللم', 'ألملم', 'ألملم', 'ألم', 'ألم', 'ألم', 'ألملم', 'ألملم', 'ألم', 'تستن', 'تستن', 'تستن', 'تستن', 'تستن', 'تستنجد', 'تستنجدجدجدجدجدجد', 'بف', 'بف', 'بف', 'بف', 'بصائلصائلصائلصائلصائل', 'منصاع', 'من', 'مش', 'من', 'مش', 'مش', 'مش', 'مش', 'مش', 'مش', 'مشرق', 'مشرق', 'مشرققرقق', 'الأرضق', 'الأرض', 'الأرض', 'الأرض', 'الأرض', 'الأرض', 'الأرض', 'ومغرب', 'ومغرب', 'ومغرب', 'ومغرب', 'ومغرب', 'ومغرب', 'ومغربهاربهانهاربهانهايةهاها', 'ماهاربها', 'فيها', 'ما', 'فيها', 'قا']

ref=['ألا', 'يرى', 'مؤيدوا', 'بشار', 'كيف', 'تحول', 'مطار', 'إحميميم', 'في', 'اللاذقية', 'إلى', 'مركز', 'تنسيق'] hyp=['ألا', 'ألا', 'إلا', 'ألا', 'ترى', 'ألا', 'ترى', 'ألا', 'يرىرى', 'يرىرى', 'يرىرى', 'مؤؤ', 'مؤي', 'مؤيدوييوؤيدووو', 'بشارو', 'بشار', 'بشار', 'بشار', 'بشار', 'بشار', 'بشار', 'بشار', 'بشار', 'بشار', 'بشار', 'كيف', 'كيف', 'كيف', 'كيف', 'كيف', 'كيف', 'كيف', 'كيف', 'كيف', 'كيف', 'كيف', 'كيف', 'كيف', 'كيف', 'كيف', 'كيف', 'كيف', 'كيف', 'تحولتحول', 'تحولتحول', 'تحول', 'تحولتحول', 'مطار', 'مطار', 'مطار', 'مطار', 'مطار', 'مطار', 'مطار', 'حميم', 'حميم', 'حميم', 'حميم', 'حميم', 'حميم', 'حميميم', 'فييميم', 'فييم', 'في', 'اللا', 'في', 'اللالا', 'الذقيةذقيةذيةقذية', 'إلىذقية', 'إلى', 'مركز', 'إلى', 'إلى', 'مركز', 'إلى', 'مركز', 'إلى', 'مركز', 'إلى', 'مركز', 'مركز', 'إلى', 'مركز', 'مركز', 'إلى', 'مركز', 'مركزه', 'مركز', 'تنسيق', 'تنسيق', 'تنسيق', 'تنسيق', 'تنسيقنسيقنسيقنسيق', 'سابقإفريقي', 'وقدورةقيبة', 'وقدرة', 'وقد']

danpovey commented 1 year ago

For that extremely long text that you found, can you see what is in it? Try sorting the texts by (num-characters / num-frames), if you can get that information, and see what kinds of things you see. If there are lots of repeated words in the texts, these could also end up being recognized.
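
Dan's suggested check can be sketched as follows, using the six text lengths and durations printed in the pdb session earlier in this thread. The sketch assumes a 10 ms frame shift (100 frames per second), the usual fbank setting; rank_by_chars_per_frame is a hypothetical helper.

```python
from typing import List, Tuple

FRAME_SHIFT = 0.01  # assumption: 10 ms frame shift


def rank_by_chars_per_frame(
    text_lens: List[int], durations: List[float]
) -> List[Tuple[int, float]]:
    """Return (index, chars_per_frame) pairs, most suspicious (highest ratio) first."""
    ratios = [
        (i, n / (d / FRAME_SHIFT)) for i, (n, d) in enumerate(zip(text_lens, durations))
    ]
    return sorted(ratios, key=lambda pair: pair[1], reverse=True)


# Values from the batch inspected in this thread
text_lens = [3746, 309, 408, 389, 198, 284]
durations = [29.66, 29.57775, 29.3818125, 28.58, 28.311125, 28.22]
ranked = rank_by_chars_per_frame(text_lens, durations)
```

The first segment has roughly an order of magnitude more characters per frame than the others (about 1.26 versus 0.07 to 0.14), which is the kind of outlier a duration-only filter cannot catch.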

AmirHussein96 commented 1 year ago

In the training data there is no systematic repetition. I used exactly the same data with the CTC model, and the predictions were very good, with no repetition issue. For RNN-T, I removed all segments with more than 450 characters (and those with fewer than 20) from the training data using the function below:

```python
from lhotse.cut import Cut


def remove_short_and_long_text(c: Cut) -> bool:
    # Keep only cuts whose text has between 20 and 450 characters
    return 20 <= len(c.supervisions[0].text) <= 450


train_cuts = train_cuts.filter(remove_short_and_long_text)
```

In the evaluation set, the segments are short, as you can see above. The issue is that during decoding the model emits all the intermediate predictions and repeats them until it finishes the word, then keeps repeating the word until the next word comes.

csukuangfj commented 1 year ago

Are there repeats for all of the three decoding methods?

AmirHussein96 commented 1 year ago

Yes, this behavior occurs with all three decoding methods.