k2-fsa / icefall

https://k2-fsa.github.io/icefall/
Apache License 2.0

OOM in rnn-t training #521

Closed yangsuxia closed 2 years ago

yangsuxia commented 2 years ago

Background information: I use the wenetspeech recipe for training. There is no OOM problem in scan_pessimistic_batches_for_oom, but OOM always occurs later in training. It still fails even when max_duration is set to 60.

Error information: RuntimeError: CUDA out of memory. Tried to allocate 50.73 GiB (GPU 0; 31.75 GiB total capacity; 3.75 GiB already allocated; 26.30 GiB free; 4.27 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Environment: pytorch-1.12, k2-1.17, python-3.7

GPU: 8 x V100, 32 GB memory each

Data: more than 10000 hours, 8000 Hz sampling rate, 40-dimensional features, audio lengths between 1 and 15 s

csukuangfj commented 2 years ago

Some people had CUDA OOM errors with CUDA 11.1 in https://github.com/k2-fsa/icefall/issues/247

I am not sure which CUDA version you are using. Can you try torch 1.10 + CUDA 10.2 ?

yangsuxia commented 2 years ago

Thank you very much for your reply. I'll try it.

yangsuxia commented 2 years ago

Why does the GPU whose tot_loss is NaN use so much memory? [screenshot]

yangsuxia commented 2 years ago

After changing the environment to pytorch 1.11 + python 3.8.5 + k2 1.17 + CUDA 10.2, with max_duration set to 120, it still OOMs; GPU 4 is the one that OOMs. Before the error, memory usage was still around 15 GB, as shown in the screenshot above. [screenshot]

csukuangfj commented 2 years ago

Are you using your own dataset? Does reducing max_duration help?

yangsuxia commented 2 years ago

Yes, I use my own dataset. I tried reducing max_duration to 60. When the batch is 78476, OOM still appears; the error reports that 50 GB of memory is required.

csukuangfj commented 2 years ago

When the batch is 78476

What does this mean? Does it mean the batch size is 78476, or does it throw the OOM at the 78476th batch?

yangsuxia commented 2 years ago

It throws the OOM at the 78476th batch.

csukuangfj commented 2 years ago

It throws the OOM at the 78476th batch.

Could you post more log messages?

There must be a .pt file saved to your exp_dir (check the log messages in the terminal). You can load it with torch.load(), check what it contains, and see whether there is anything peculiar in the batch, e.g., the number of tokens, utterance durations, etc.
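
Something like the following is enough for that inspection (a minimal sketch; the file name is a placeholder, and the "inputs"/"supervisions" keys are what the icefall recipes usually produce, so adjust to what your log actually saved):

import torch

# The file name is a placeholder; use the one printed in your training log.
batch = torch.load("exp_dir/batch-xxxx.pt", map_location="cpu")

print(type(batch))
if isinstance(batch, dict):
    # icefall batches are typically dicts with "inputs" (features) and
    # "supervisions" (texts, durations, etc.).
    for key, value in batch.items():
        if torch.is_tensor(value):
            print(key, tuple(value.shape), value.dtype)
        else:
            print(key, value)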

yangsuxia commented 2 years ago

Number of tokens: 8007. Parameters: 93M. I loaded 72000.pt; see the following screenshot for the sampler information: [screenshot] The detailed error information is shown in the screenshot below: [screenshot]

csukuangfj commented 2 years ago

I see. pruned_transducer_stateless2/train.py in wenetspeech does not have the following function, which will save the problematic batch on exception. https://github.com/k2-fsa/icefall/blob/7157f62af3b7712eda2186f7e3d253df7cde65b5/egs/librispeech/ASR/pruned_transducer_stateless2/train.py#L1007

Could you copy the code from librispeech to wenetspeech and post the log after the change?

If possible, would you mind making a PR with your changes adding display_and_save_batch to wenetspeech?
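
In case it helps, here is a rough, simplified stand-in for what that function does (the real display_and_save_batch in the linked librispeech file also logs the number of tokens; the batch keys below assume the usual lhotse collator output):

import logging
import torch

def save_problematic_batch(batch: dict, exp_dir: str) -> None:
    # Dump the batch that raised the exception so it can be inspected
    # offline with torch.load().
    filename = f"{exp_dir}/problematic-batch.pt"
    logging.info(f"Saving batch to {filename}")
    torch.save(batch, filename)

    # Assumed lhotse/icefall batch layout; verify against your collator.
    features = batch["inputs"]
    texts = batch["supervisions"]["text"]
    logging.info(f"features shape: {features.shape}")
    logging.info(f"num utterances: {len(texts)}")
    logging.info(f"max text length: {max(len(t) for t in texts)}")

If I remember correctly, in the librispeech train.py the call sits in an except block around the loss computation and the exception is re-raised afterwards, so training still stops but the offending batch is kept on disk.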

csukuangfj commented 2 years ago

@luomingshuang

luomingshuang commented 2 years ago

Are you sure that you have used trim_to_supervisions in compute_fbank_feature_wenetspeech.py to clip long utterances into short utterances?
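
For context, the trimming happens at the cut level in lhotse. A minimal sketch, with a placeholder manifest path:

from lhotse import CutSet

# Placeholder path; point this at your own cuts manifest.
cuts = CutSet.from_file("data/fbank/cuts_train.jsonl.gz")

# Split each recording-level cut into one cut per supervision segment, so
# that no training example is longer than its transcript segment.
cuts = cuts.trim_to_supervisions(keep_overlapping=False)

cuts.to_file("data/fbank/cuts_train_trimmed.jsonl.gz")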

luomingshuang commented 2 years ago

I notice that your number of tokens is 8007. In my training on wenetspeech, the number of tokens is 5537. Comparing the two, 8007 is quite large; I think it may require more memory for computation. If possible, I suggest you try the 5537-token vocabulary. You can get the tokens.txt with 5537 tokens from https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2/blob/main/data/lang_char/tokens.txt .

pkufool commented 2 years ago

I notice that your number of tokens is 8007. In my training on wenetspeech, the number of tokens is 5537. Comparing the two, 8007 is quite large; I think it may require more memory for computation. If possible, I suggest you try the 5537-token vocabulary. You can get the tokens.txt with 5537 tokens from https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2/blob/main/data/lang_char/tokens.txt .

I think he uses his own dataset, so your tokens might not be suitable for him. I'd rather he dump the problematic batches, so we can debug the code and see whether there are any bugs in our loss.

danpovey commented 2 years ago

It is not a regular OOM issue; it is a bug, because it is trying to allocate nearly 60 GB. He needs to run it under gdb with, e.g.,

gdb --args python3 script.py args...
(gdb) r

so we can get a stack trace. It may be the same issue reported in #396 ... the backprop of PyTorch happens in C++, so the only kind of stack trace that might help us here is the C++ stack trace.

yangsuxia commented 2 years ago

I see. pruned_transducer_stateless2/train.py in wenetspeech does not have the following function, which will save the problematic batch on exception.

https://github.com/k2-fsa/icefall/blob/7157f62af3b7712eda2186f7e3d253df7cde65b5/egs/librispeech/ASR/pruned_transducer_stateless2/train.py#L1007

Could you copy the code from librispeech to wenetspeech and post the log after the change?

If possible, would you mind making a PR with your changes adding display_and_save_batch to wenetspeech?

OK, I'll try

yangsuxia commented 2 years ago

Are you sure that you have used trim_to_supervisions in compute_fbank_feature_wenetspeech.py to clip long utterances into short utterances?

This may be the problem. I used features that I extracted earlier. Can pre-extracted features be segmented?

The training data information I use is as follows: [screenshot]

yangsuxia commented 2 years ago

Are you sure that you have used trim_to_supervisions in compute_fbank_feature_wenetspeech.py to clip long utterances into short utterances?

I just looked at the function trim_to_supervisions. The audio in wenetspeech is very long, so it is necessary to trim according to the supervisions. However, my data is between 1 and 15 s and should not need trimming, so I still can't determine what the problem is.

csukuangfj commented 2 years ago

so I still can't determine what the problem is.

Could you follow https://github.com/k2-fsa/icefall/issues/521#issuecomment-1206208083 to get the batch that causes the OOM?

yangsuxia commented 2 years ago

so I still can't determine what the problem is.

Could you follow #521 (comment) to get the batch that causes the OOM?

OK, I will try it when my GPU is free. Given my GPU, what is an appropriate max_duration to set? 180?

csukuangfj commented 2 years ago

OK, I will try it when my GPU is free. Given my GPU, what is an appropriate max_duration to set? 180?

You can try 200 first and use nvidia-smi to view the GPU RAM usage. If it is low, you can increase it. If it results in OOM, you can reduce it.

yangsuxia commented 2 years ago

The problem has been solved. Thank you all for your replies. It was a problem with my data: the transcripts of some utterances were too long, resulting in a sudden spike in memory usage.
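
For anyone who runs into the same thing, a minimal sketch of filtering out such cuts with lhotse (the manifest path and the 100-character threshold below are placeholders, not the exact values used here):

from lhotse import CutSet

MAX_TEXT_LEN = 100  # placeholder: maximum characters per supervision

cuts = CutSet.from_file("data/fbank/cuts_train.jsonl.gz")  # placeholder path

def text_is_reasonable(cut) -> bool:
    # Drop cuts whose transcript is unexpectedly long; such cuts make the
    # transducer loss allocate far more memory than usual.
    return all(len(s.text) <= MAX_TEXT_LEN for s in cut.supervisions)

cuts = cuts.filter(text_is_reasonable)
cuts.to_file("data/fbank/cuts_train_filtered.jsonl.gz")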