egoetz opened this issue 5 months ago
This is caused by flash attention. Please disable it and use the original self-attention implementation.
Also, use the default training config. Here is the batch setting for 4 A100 80GB GPUs:
30 batch size * 1024 block size * 4 gradaccum * 4 GPUs = 491,520
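The arithmetic behind these settings, using symbols introduced here for illustration: B is the per-GPU micro-batch size (batch_size), T the block size (block_size), G the gradient accumulation steps per GPU, and N the number of GPUs. Each optimizer step then consumes:

```latex
\begin{aligned}
\text{tokens per step} &= B \cdot T \cdot G \cdot N \\
30 \cdot 1024 \cdot 4 \cdot 4 &= 491{,}520 \quad \text{(4 GPUs)} \\
12 \cdot 1024 \cdot 5 \cdot 8 &= 491{,}520 \quad \text{(8 GPUs)}
\end{aligned}
```

Going from 8 to 4 GPUs doubles the required product B·G (from 60 to 120), which the suggested setting covers with a larger batch and slightly fewer accumulation steps.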
How do you disable flash attention? I can't find anything on the PyTorch website suggesting that it can be toggled.
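Two places a toggle may live, both worth checking against your own copies. In nanoGPT's model.py, `CausalSelfAttention` appears to set `self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')` and to fall back to a manual masked-softmax path (registering a causal `bias` buffer) when that flag is false, so changing that assignment to `self.flash = False` should select the original attention, assuming your model.py has that structure. PyTorch itself also provides `torch.backends.cuda.enable_flash_sdp(False)`, which only turns off the flash kernel inside `scaled_dot_product_attention` rather than replacing it. For reference, the manual path computes ordinary causal softmax attention; the sketch below is a single-head, pure-Python illustration of that computation (the function name is hypothetical, it is not nanoGPT's code, and it is far too slow for training):

```python
import math


def causal_self_attention(q, k, v):
    """Single-head causal attention on lists of row vectors (T x d)."""
    d = len(q[0])
    out = []
    for i in range(len(q)):
        # Causal mask: position i only scores keys at positions 0..i.
        scores = [
            sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(d)
            for j in range(i + 1)
        ]
        # Numerically stable softmax over the unmasked scores.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum of the value vectors visible to position i.
        out.append(
            [sum(w * v[j][c] for j, w in enumerate(weights)) for c in range(len(v[0]))]
        )
    return out
```

The first position can only attend to itself, so its output equals `v[0]`; later positions mix all earlier values.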
Is there a way to find the correct configuration for an arbitrary setup? Based on your comment and the original script, I'm not sure when to change the batch size versus the gradient accumulation steps:
30 batch size * 1024 block size * 4 gradaccum * 4 GPUs = 491,520
batch_size = 30
block_size = 1024
gradient_accumulation_steps = 4 * 4 (= 16)
train_gpt2.py:
- 8 A100s, one to two nodes
12 batch size * 1024 block size * 5 gradaccum * 8 GPUs = 491,520
batch_size = 12
block_size = 1024
gradient_accumulation_steps = 5 * 8 (= 40)
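One way to approach an arbitrary setup, sketched under the assumption that the goal is to keep roughly 491,520 tokens per optimizer step: raise batch_size until each GPU's memory is nearly full (a common rule of thumb, not something stated in this thread), then choose the accumulation count that makes up the rest. The helper names below are hypothetical. The sketch also assumes, as in recent versions of nanoGPT's train.py, that the config value is the total across GPUs and that train.py divides it by the DDP world size, which would explain why the configs write `5 * 8` or `4 * 4`.

```python
def tokens_per_step(batch_size, block_size, grad_accum_per_gpu, n_gpus):
    """Tokens processed per optimizer step, summed over all GPUs."""
    return batch_size * block_size * grad_accum_per_gpu * n_gpus


def grad_accum_for(target_tokens, batch_size, block_size, n_gpus):
    """Pick gradient accumulation for a target number of tokens per step.

    Returns (per_gpu_steps, config_value). per_gpu_steps is the smallest
    count that reaches at least target_tokens; config_value is what would go
    in the config file if train.py divides it by the world size (assumption).
    """
    tokens_per_micro_step = batch_size * block_size * n_gpus
    # Integer ceiling division, so the target is never undershot.
    per_gpu = max(1, -(-target_tokens // tokens_per_micro_step))
    return per_gpu, per_gpu * n_gpus
```

For example, `grad_accum_for(491_520, 24, 1024, 4)` returns `(5, 20)`: the batch_size = 24 used on 4 GPUs in the original post would need 5 accumulation steps per GPU (a config value of `5 * 4`) to reach the same 491,520 tokens per step. When the target is not an exact multiple, the ceiling overshoots it slightly, and adjusting batch_size is the way to land closer.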
I will try your suggested parameters on the 4 A100s. I have also had much better luck recently getting access to more GPUs, so I will try to replicate the 8-A100 training run described in the README. I changed the parameters back to their defaults for 8 GPUs, but flash attention is still enabled.
Are you suggesting that the divergence may be caused by numerical instability in flash attention? I have changed the default dtype from bfloat16 to float32 in the hope that the higher precision will fix it.
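To make the precision difference concrete: bfloat16 keeps float32's 8-bit exponent but only 8 significant bits (7 stored), versus 24 for float32. The stdlib sketch below (the function name is hypothetical) rounds a float32 value to the nearest bfloat16 using round-half-to-even on the 16 dropped bits. It illustrates the number format only and says nothing about how PyTorch autocast stores weights or accumulates results; it is meant for finite inputs within float32 range.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round x (as float32) to the nearest bfloat16 value. Finite inputs only."""
    # Reinterpret the float32 encoding of x as a 32-bit unsigned integer.
    (bits,) = struct.unpack(">I", struct.pack(">f", x))
    # Round half to even on the low 16 bits, which bfloat16 discards.
    lsb = (bits >> 16) & 1
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    # Reinterpret the truncated pattern back as a float32.
    return struct.unpack(">f", struct.pack(">I", bits))[0]
```

With a spacing of 2**-7 (about 0.0078) near 1.0, `to_bfloat16(1.0 + 6e-5)` returns `1.0`: an increment of 6e-5 vanishes in bfloat16, while float32 (spacing about 1.2e-7 near 1.0) retains it.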
I have been trying to follow the steps under "reproducing GPT-2" in README.md. Unfortunately, my training run always diverges. I have tried changing the learning rate and the gradient accumulation, but neither helped (although I did have to fix a bug in the learning rate after varying those parameters). I could try adjusting them again, but my latest runs suggest that neither parameter is the issue:
Here are the last two runs. The orange run decays the learning rate over 300,000 steps, and the pink run decays it over 600,000 steps. In both runs the learning rate starts at 6e-5 and decays to a minimum of 6e-6.
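Both runs presumably use a warmup-then-cosine-decay schedule in the style of nanoGPT's `get_lr` (an assumption; the thread does not show the code). A minimal sketch under that assumption, with a hypothetical `warmup_iters = 2000` since the warmup length is not stated, and the peak and floor taken from the runs above:

```python
import math


def get_lr(it, learning_rate=6e-5, min_lr=6e-6, warmup_iters=2000, lr_decay_iters=300_000):
    # 1) Linear warmup from 0 up to learning_rate.
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) Past the decay horizon, hold at the floor.
    if it > lr_decay_iters:
        return min_lr
    # 3) Cosine decay from learning_rate down to min_lr.
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes from 1 to 0
    return min_lr + coeff * (learning_rate - min_lr)
```

Passing `lr_decay_iters=600_000` keeps the rate higher at any given step, which is the difference between the pink and orange runs. Checking the endpoints (zero at step 0, the peak right after warmup, `min_lr` at `lr_decay_iters`) is a quick way to catch schedule bugs like the one mentioned above.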
Here are some of my hyperparameters:
batch_size = 24
block_size = 1024
max_iters = 300000
lr_decay_iters = 300000
eval_interval = 1000
eval_iters = 200
log_interval = 100
weight_decay = 5e-2
I am running this model on 4 A100 80GB GPUs.