microsoft / torchscale

Foundation Architecture for (M)LLMs
https://aka.ms/GeneralAI
MIT License

Training RetNet on A100 GPUs #83

Open: Antoine-Bergerault opened this issue 8 months ago

Antoine-Bergerault commented 8 months ago

Hello,

To train RetNet, I followed the blog post https://zenn.dev/selllous/articles/retnet_tutorial (shared in #52), and it seems to work well for small models (< 3B parameters).

However, I cannot train retnet_3b without running out of memory. For now I just want to get it running, but even with a very small --batch-size and --max-tokens I still hit out-of-memory (OOM) errors.

cd torchscale/examples/fairseq/
python train.py ../../../fairseq/examples/language_model/data-bin/wikitext-103 \
  --task language_modeling \
  --save-dir checkpoints/retnet_3b/transformer_wikitext-103 \
  --arch retnet_3b --share-decoder-input-output-embed \
  --save-interval 1 \
  --dropout 0.1 \
  --optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 --clip-norm 0.0 \
  --lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
  --max-tokens 512 --update-freq 16 \
  --fp16 \
  --batch-size 2 \
  --max-update 1 \
  --tokens-per-sample 512
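To see how small each training step already is, the batching flags in the command can be turned into per-batch and per-update token counts. This is a minimal sketch that assumes fairseq packs the text into fixed blocks of --tokens-per-sample tokens, that each batch is capped by both --max-tokens (padded tokens) and --batch-size, and that --update-freq accumulates gradients over that many batches before each optimizer step; the helper names are hypothetical.

```python
def sequences_per_batch(max_tokens: int, tokens_per_sample: int, batch_size: int) -> int:
    """Sequences per forward/backward pass when both caps apply.

    Assumes every sample is exactly tokens_per_sample tokens long, so the
    token cap admits max_tokens // tokens_per_sample sequences per batch.
    """
    return min(batch_size, max_tokens // tokens_per_sample)


def tokens_per_update(max_tokens: int, tokens_per_sample: int, batch_size: int,
                      update_freq: int, num_gpus: int = 1) -> int:
    """Tokens contributing to one optimizer step with gradient accumulation.

    update_freq micro-batches are accumulated before each step, so peak
    activation memory is set by a single micro-batch, not by this total.
    """
    seqs = sequences_per_batch(max_tokens, tokens_per_sample, batch_size)
    return seqs * tokens_per_sample * update_freq * num_gpus


# Values taken from the command above.
print(sequences_per_batch(512, 512, 2))    # 1
print(tokens_per_update(512, 512, 2, 16))  # 8192
```

Under these assumptions, each forward/backward pass already sees a single 512-token sequence, which suggests that shrinking the batch further will not help much and that most of the memory goes to the parameters, gradients and optimizer state rather than to activations.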

The OOM seems to come from the backward pass: the call to optimizer.step() at line 498 of fairseq_task.py always fails with:

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 11.96 GiB. GPU 0 has a total capacty of 79.15 GiB of which 7.20 GiB is free. Process 75705 has 71.94 GiB memory in use. Of the allocated memory 65.77 GiB is allocated by PyTorch, and 3.41 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

What would you recommend for training a model of this size? Is there a way to train it on one or more A100 GPUs with 80 GiB of memory each?

I understand that I may need to partition the model across multiple GPUs, but I am not familiar with this, so any help would be appreciated.
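A rough way to approach the 80 GiB question is to count bytes per parameter for the weights, gradients and Adam state. The sketch below is back-of-the-envelope arithmetic under stated assumptions, not a measurement: the byte breakdowns for fairseq's --fp16 (fp32 master weights and gradients kept alongside the fp16 copies) and --memory-efficient-fp16 (no fp32 master copies) are assumed, 3e9 is a round, hypothetical parameter count for retnet_3b, and activations, temporary buffers and allocator fragmentation are ignored. The `num_shards` parameter is a hypothetical knob modelling all of that state being split evenly across GPUs (ZeRO-3/FSDP style); it is not a fairseq option.

```python
GIB = 2 ** 30

# Assumed bytes per parameter (a rough model, not measured from fairseq):
# --fp16: fp16 weights (2) + fp16 grads (2) + fp32 master weights (4)
#         + fp32 master grads (4) + fp32 Adam moments m and v (4 + 4)
BYTES_FP16 = 2 + 2 + 4 + 4 + 4 + 4
# --memory-efficient-fp16: fp16 weights (2) + fp16 grads (2)
#         + fp32 Adam moments m and v (4 + 4), no fp32 master copies
BYTES_MEM_EFF_FP16 = 2 + 2 + 4 + 4


def state_gib(num_params: float, bytes_per_param: int, num_shards: int = 1) -> float:
    """GiB of weights + gradients + Adam state per GPU, excluding activations.

    num_shards > 1 models splitting all of that state evenly across GPUs;
    plain data parallelism keeps a full copy on every GPU (num_shards=1).
    """
    return num_params * bytes_per_param / num_shards / GIB


if __name__ == "__main__":
    n = 3e9  # hypothetical round parameter count for retnet_3b
    for label, b in [("--fp16", BYTES_FP16), ("--memory-efficient-fp16", BYTES_MEM_EFF_FP16)]:
        for shards in (1, 2, 4, 8):
            print(f"{label:>24} shards={shards}: {state_gib(n, b, shards):5.1f} GiB")
```

Under these assumptions, --fp16 state alone is about 56 GiB per GPU with plain data parallelism, which leaves little of an 80 GiB A100 for activations and temporary buffers, while dropping the fp32 master copies brings it to roughly 34 GiB, and sharding divides either figure by the number of GPUs.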

shumingma commented 7 months ago

You can try adding --memory-efficient-fp16 and --checkpoint-activations, which can significantly reduce memory consumption.
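Activation checkpointing trades compute for memory: instead of keeping every intermediate value for the backward pass, only some are kept and the rest are recomputed during backward. The toy below illustrates that trade-off in pure Python on a chain of scalar tanh "layers", which are hypothetical stand-ins for model blocks; it is not torchscale's or fairseq's implementation, which wraps real layers and recomputes GPU tensors (typically keeping each layer's input and recomputing that layer's internals). Here a segment of k scalar layers plays the role of one checkpointed block.

```python
import math
from typing import Dict, List, Tuple


def layer(w: float, x: float) -> float:
    """Toy scalar 'layer' standing in for a model block: y = tanh(w * x)."""
    return math.tanh(w * x)


def layer_grad(w: float, x: float) -> float:
    """Local derivative dy/dx of y = tanh(w * x), evaluated at input x."""
    t = math.tanh(w * x)
    return w * (1.0 - t * t)


def grad_full(ws: List[float], x0: float) -> Tuple[float, int]:
    """Plain backprop: keep every layer input for the backward pass.

    Returns (d output / d x0, number of stored layer inputs).
    """
    inputs = []
    x = x0
    for w in ws:
        inputs.append(x)  # saved for backward
        x = layer(w, x)
    g = 1.0  # gradient of the final output w.r.t. itself
    for i in range(len(ws) - 1, -1, -1):
        g *= layer_grad(ws[i], inputs[i])
    return g, len(inputs)


def grad_checkpointed(ws: List[float], x0: float, k: int) -> Tuple[float, int]:
    """Backprop keeping only every k-th layer input; the rest are recomputed.

    Returns (d output / d x0, peak number of layer inputs held at once).
    """
    n = len(ws)
    checkpoints: Dict[int, float] = {}
    x = x0
    for i, w in enumerate(ws):
        if i % k == 0:
            checkpoints[i] = x  # only segment boundaries are kept
        x = layer(w, x)
    g = 1.0
    peak = 0
    for start in sorted(checkpoints, reverse=True):
        end = min(start + k, n)
        # Recompute this segment's layer inputs from its checkpoint (extra forward work).
        seg_inputs = [checkpoints[start]]
        for i in range(start, end - 1):
            seg_inputs.append(layer(ws[i], seg_inputs[-1]))
        # seg_inputs[0] is the checkpoint itself, so it is not counted twice.
        peak = max(peak, len(checkpoints) + len(seg_inputs) - 1)
        for i in range(end - 1, start - 1, -1):
            g *= layer_grad(ws[i], seg_inputs[i - start])
    return g, peak
```

With 16 layers and k = 4, plain backprop keeps 16 layer inputs while the checkpointed version holds at most 7, at the cost of recomputing 12 of the 16 layer forwards; the gradient is the same because the recomputation is deterministic.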