zyushun / Adam-mini

Code for Adam-mini: Use Fewer Learning Rates To Gain More https://arxiv.org/abs/2406.16793

Memory saving only in checkpoint size, not during training #8

Open aditya2331 opened 3 weeks ago

aditya2331 commented 3 weeks ago

Hi, I tried running the standard train_gpt2.py script with Adam_mini. I had to remove the ValueError and the Hessian spectrum import to make it work properly. I noticed a good reduction in checkpoint size but no memory reduction during training, so I couldn't fit a larger batch size for more throughput. The memory usage during training came out the same as with regular AdamW.

Training Details:

DDP with 4xA100 80GB GPUs, batch size 12, 336 gradient accumulation steps, GPT2 XL model with weight tying removed
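For anyone reproducing this, a minimal sketch of one way to isolate the optimizer-state footprint is below: it compares allocated memory right before and right after the first optimizer.step(), where PyTorch optimizers typically allocate their state lazily. It uses only standard torch.cuda calls; report_step_memory, loss_fn, and batch are illustrative names, and the assumption that Adam_mini also allocates its state at the first step may not hold.

```python
# Minimal sketch (not from the Adam-mini repo): compare allocated GPU memory
# right before and right after the first optimizer.step(), where PyTorch
# optimizers typically allocate their state tensors lazily. Assumes a single
# CUDA device and that Adam_mini follows the same lazy-allocation convention.
import torch

def report_step_memory(model, optimizer, loss_fn, batch, device="cuda"):
    # `loss_fn` and `batch` are illustrative; adapt to the actual training loop.
    model.to(device)
    x, y = (t.to(device) for t in batch)

    loss = loss_fn(model(x), y)
    loss.backward()

    torch.cuda.synchronize(device)
    before = torch.cuda.memory_allocated(device)

    optimizer.step()          # optimizer state is usually created here
    optimizer.zero_grad(set_to_none=True)

    torch.cuda.synchronize(device)
    after = torch.cuda.memory_allocated(device)
    peak = torch.cuda.max_memory_allocated(device)

    print(f"allocated before first step: {before / 2**30:.2f} GiB")
    print(f"allocated after first step:  {after / 2**30:.2f} GiB")
    print(f"peak so far:                 {peak / 2**30:.2f} GiB")
```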

chcoliang commented 3 weeks ago

Hi! Thanks for your interest and your question!

We are not sure why there is a reduction in checkpoint size but no memory reduction during training; this seems unusual. For GPT2-1.5B (without weight tying), there should be a ~5 GB memory reduction per GPU during training (a rough estimate of where this number comes from is sketched after the list below). Here are our suggestions.

  1. Please check whether there is a ~5 GB reduction per GPU. If there is no memory reduction at all, check whether the optimizer states are being offloaded to the CPU.

  2. If there is indeed a ~5 GB reduction, there should be some room for a larger batch size. Please try increasing the batch size gradually to 14, 16, etc.

  3. If none of the above helps, please try putting weight-tying back and see if it is better.
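As a rough, back-of-the-envelope sketch of where the ~5 GB figure in point 1 comes from (my own arithmetic, using an approximate parameter count and the simplification that Adam-mini's block-wise second moment is negligible):

```python
# Back-of-the-envelope sketch (assumptions, not measurements): under plain DDP
# every rank holds the full optimizer state in fp32. AdamW keeps two states
# (m and v) per parameter; Adam-mini keeps m but replaces the per-parameter v
# with roughly one scalar per block, so the v buffer mostly disappears.
n_params = 1.56e9          # GPT-2 XL without weight tying, approximate
bytes_fp32 = 4

adamw_state = 2 * n_params * bytes_fp32    # m + v
mini_state  = 1 * n_params * bytes_fp32    # m only (block-wise v is negligible)

print(f"AdamW optimizer state   : {adamw_state / 2**30:.1f} GiB per GPU")
print(f"Adam-mini state (approx): {mini_state / 2**30:.1f} GiB per GPU")
print(f"expected saving         : {(adamw_state - mini_state) / 2**30:.1f} GiB per GPU")
# ~5.8 GiB, consistent with the ~5 GB reduction mentioned above (some layers,
# e.g. embeddings, may still keep full Adam states, which lowers the saving a bit).
```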

Please feel free to update your findings here!

aditya2331 commented 3 weeks ago

The first 4 GPUs are using Adam-mini, the latter 4 are using AdamW. Interestingly, switching to AdamW on your codebase runs into an OOM error, but the original nanoGPT runs fine.

[Screenshot: per-GPU memory usage, 08-07-24 at 11:44 AM]

As can be seen, one GPU is using more memory than the others under Adam-mini; I'm not sure why. This is the original GPT2 XL model without weight tying. Batch size and gradient accumulation steps are the same as above (12 and 336).

There is no CPU offloading.
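A quick way to double-check the offloading question and see how much state the optimizer actually holds is to walk the standard optimizer.state mapping, as in the sketch below. Note that a custom optimizer like Adam_mini may not store its state in this standard layout, so treat this as a best-effort check.

```python
# Quick check sketch: where do the optimizer state tensors live, and how big
# are they in total? Assumes the optimizer keeps its state in the standard
# `optimizer.state` mapping of per-parameter dicts; a custom optimizer such as
# Adam_mini may organize its state differently, in which case adapt accordingly.
import collections
import torch

def summarize_optimizer_state(optimizer):
    bytes_per_device = collections.Counter()
    for per_param_state in optimizer.state.values():
        for value in per_param_state.values():
            if torch.is_tensor(value):
                bytes_per_device[str(value.device)] += value.numel() * value.element_size()
    for device, nbytes in bytes_per_device.items():
        print(f"{device}: {nbytes / 2**30:.2f} GiB of optimizer state")
```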

zyushun commented 3 weeks ago

Hi @aditya2331! Thanks for the update!

Regarding your figure: it seems there is only a ~300 MB memory reduction on each GPU. This is unexpected, because normally there should be a ~5 GB reduction in your setting. Could you double-check your checkpoint size to see if there is actually a ~5 GB reduction there?

Regarding your comment "adamw on your codebase runs into an OOM error, but using the original nanoGPT runs fine." : this also seems unusual since our code is essentially nanoGPT. Perhaps a simple way debug is to: import and run Adam-mini on the original nanoGPT code to see if things get better.

Please feel free to further update here!

980202006 commented 3 weeks ago

Same issue with the Hugging Face Trainer.

zyushun commented 3 weeks ago

Hi @980202006! This issue seems unusual. We do observe substantial memory reduction with the Hugging Face Trainer. We share our results as follows.

Setting: We use the Hugging Face Trainer under the LLaMA-Factory codebase. We conduct SFT on Llama2-7B with gradient checkpointing, batch size = 4, and DeepSpeed ZeRO-3.

The per-GPU memory usage is as follows:

Adam-mini: 34435.0 MB  35533.0 MB  36701.0 MB  35207.0 MB

AdamW:     49027.0 MB  48985.0 MB  48991.0 MB  48439.0 MB

Could you share more training details on your side? It would help us debug. Thanks a lot!
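One convenient way to share those details is to log each rank's peak memory from the Trainer itself; the sketch below uses the standard transformers.TrainerCallback hooks, with an arbitrary class name.

```python
# Illustrative sketch: log each rank's peak GPU memory from a Hugging Face
# Trainer run so the numbers can be compared across optimizers. Uses the
# standard transformers.TrainerCallback hooks; the class name is arbitrary.
import torch
from transformers import TrainerCallback

class PeakMemoryCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % 50 == 0 and torch.cuda.is_available():
            peak_mb = torch.cuda.max_memory_allocated() / 2**20
            print(f"[rank {args.local_rank}] step {state.global_step}: "
                  f"peak {peak_mb:.0f} MB")

# Usage: trainer = Trainer(..., callbacks=[PeakMemoryCallback()])
```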

980202006 commented 3 weeks ago

@zyushun Thank you! I will try again.

zyushun commented 2 weeks ago

Hi @aditya2331, regarding "switching to adamw on your codebase runs into an OOM error, but using the original nanoGPT runs fine."

I found it might be due to the different attention implementations between our model.py (which is actually an older version of nanoGPT's) and the latest model.py in nanoGPT. The difference lies in CausalSelfAttention: the old version uses separate Q, K, V projections, while the latest one combines them into a single fused projection. Mathematically, these two implementations should be equivalent, but computationally they can lead to different memory consumption.
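For clarity, a simplified sketch of the two projection styles being compared is below; it is paraphrased rather than copied from either model.py, so treat the exact layer names as illustrative.

```python
# Simplified sketch (paraphrased, not copied from either model.py) of the two
# CausalSelfAttention projection styles being compared. Both compute the same
# Q, K, V; the fused version does it with one matmul followed by a split.
import torch.nn as nn

class SeparateQKV(nn.Module):          # style of the older model.py described above
    def __init__(self, n_embd):
        super().__init__()
        self.q_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.k_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.v_proj = nn.Linear(n_embd, n_embd, bias=False)

    def project(self, x):
        return self.q_proj(x), self.k_proj(x), self.v_proj(x)

class FusedQKV(nn.Module):             # style of the latest nanoGPT model.py
    def __init__(self, n_embd):
        super().__init__()
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=False)

    def project(self, x):
        q, k, v = self.c_attn(x).split(self.c_attn.out_features // 3, dim=-1)
        return q, k, v
```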

We recommend using our model.py to avoid potential unexpected errors of this kind.