prajdabre / yanmtt

Yet Another Neural Machine Translation Toolkit
MIT License

CUDA error while pre-training BART & how to use --hard_truncate_length #8

Closed · GorkaUrbizu closed this 2 years ago

GorkaUrbizu commented 2 years ago

Hi again,

After getting the NaN loss error from the previous issue, I launched another training run over the weekend:

python3 pretrain_nmt.py -n 1 -nr 0 -g 2 --model_path models/bart_base_512 \
--tokenizer_name_or_path tokenizers/mbart-bpe50k \
--langs xx --mono_src data/train.xx \
--batch_size 4096 \
--multistep_optimizer_steps 16 \
--num_batches 500000 \
--warmup_steps 16000 \
--encoder_layers 6 \
--decoder_layers 6 \
--max_length 512 \
--encoder_attention_heads 12 \
--decoder_attention_heads 12 \
--decoder_ffn_dim 3072 \
--encoder_ffn_dim 3072 \
--d_model 768 \
--fp16

This gave me the following error after 11K steps:

11920 2.9054422
11930 2.8778658
11940 2.9062994
11950 2.906765
11960 2.8594751
11970 2.8594935
terminate called after throwing an instance of 'c10::CUDAError'
  what():  CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stack trace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

(+ many error lines)

I don't know what caused this, so I will run the next trainings with CUDA_LAUNCH_BLOCKING=1 enabled.

I also want to use the --hard_truncate_length argument, in case the problem is caused by the length of the sequences.

But I'm not sure I fully understand what the --hard_truncate_length argument does. Let's say I want to train a model with --max_length=128 and --batch_size=4096... if I understood correctly, I should set --hard_truncate_length to 4096 too, right?

Thanks for your time. Regards, Gorka

prajdabre commented 2 years ago

If you see something like "out of memory" in your CUDA error, then it's memory related, in which case you should use smaller batches with more gradient accumulation steps.

Unless I see your error log I can't be sure why this happens, but it's not a bug in the code.

As for the --hard_truncate_length argument, it is used to make sure that the maximum length does not exceed 1024 subwords.

--max_length=128 acts as a truncation limit on the raw sentence and counts words (not subwords). However, there may be a sentence containing a random but very long string of, say, 1000 characters. The sentence length is then 1 word, but after subword segmentation it may go beyond 1024 tokens, which the model does not handle by default. This is where --hard_truncate_length comes into play. The maximum value it should take is the same as "max_position_embeddings" in the mBART config, which is 1024 by default. I usually set it to 256 to account for long sentences and 1024 for documents. If you wish to go beyond this, then you also have to set max_position_embeddings accordingly.
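
Roughly, the two limits interact like this (a minimal sketch, assuming a generic subword tokenizer with an encode method; the function name and defaults are illustrative, not the toolkit's actual code):

# Illustrative only: --max_length truncates words, --hard_truncate_length truncates subwords.
def truncate_example(sentence, tokenizer, max_length=128, hard_truncate_length=1024):
    # Step 1: word-level truncation (what --max_length controls).
    words = sentence.split()[:max_length]
    # Step 2: subword segmentation can still explode (e.g. one very long token),
    # so --hard_truncate_length caps the subword length; it should not exceed
    # the model's max_position_embeddings (1024 for mBART by default).
    subword_ids = tokenizer.encode(" ".join(words))
    return subword_ids[:hard_truncate_length]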

My batch construction algo is like this:

  1. Read a sentence. If it has more than N (max_length) words, truncate it to N words. (I want to avoid discarding long sentences, so I truncate them.)
  2. Subword-segment it. If its subword-segmented length would push the batch over the size limit, keep it for the next batch and return the existing batch; otherwise add it to the current batch.

There is a possibility that a single problematic sentence, even after subword segmentation, goes above the batch size. In this case the code will run in an infinite loop without returning any batches. However, this is probably never going to happen. I plan to push a bunch of new updates in the coming few days which take care of this unlikely issue too. Overall, I would not worry.
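
Below is a simplified sketch of the batching logic described in the two steps above (illustrative only; the actual implementation lives in the toolkit's data iterator and differs in detail):

# Illustrative only: token-count batching with carry-over of the sentence that overflows.
def token_batches(sentences, tokenizer, batch_size=4096, max_length=128, hard_truncate_length=1024):
    batch, batch_tokens = [], 0
    for sentence in sentences:
        # Truncate long sentences to max_length words instead of discarding them.
        words = sentence.split()[:max_length]
        # Subword-segment and apply the hard cap on subword length.
        ids = tokenizer.encode(" ".join(words))[:hard_truncate_length]
        # If adding this sentence would push the batch over the token budget,
        # return the existing batch and keep the sentence for the next one.
        if batch and batch_tokens + len(ids) > batch_size:
            yield batch
            batch, batch_tokens = [], 0
        batch.append(ids)
        batch_tokens += len(ids)
    if batch:
        yield batch

Note that in this sketch a single over-long sentence simply ends up alone in its own batch rather than causing an infinite loop, and with the hard truncation applied a sentence cannot exceed the token budget in a typical setup anyway (1024 < 4096).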

GorkaUrbizu commented 2 years ago

Thanks for the details! That was very helpful!

My error doesn't seem related to --hard_truncate_length, but I will set --hard_truncate_length to 1024 just in case from now on. I will ignore the error for now, and if I get it again I will come back with the full log to ask for your help.

PS: It wasn't an OOM error, I know those all too well hahaha