pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Question about torchtune's Low Memory Footprint #1758

Open nappaillav opened 4 days ago

nappaillav commented 4 days ago

First of all, thank you for developing torchtune. It has been very helpful for our group, which has limited GPU credits. I'm impressed by its capabilities, particularly its memory efficiency. I've noticed that torchtune has a significantly lower memory footprint than the same model trained with Hugging Face's TRL/Transformers libraries. For example:

- Full fine-tuning of LLaMA 3B with the Hugging Face libraries takes about 35GB on an A100 GPU.
- The same model with torchtune uses less than 20GB.

Reference: the memory table in torchtune's README.

Question

I'm very curious about how torchtune achieves this impressive memory efficiency. Could you provide some insight into the specific changes or techniques that make this possible (specifically for full fine-tuning)? Some areas I'm particularly interested in:

ebsmothers commented 3 days ago

Hi @nappaillav, thanks for creating the issue, and glad to hear that you've found torchtune useful. We use a number of different memory-saving techniques, and the numbers from the table you referenced should be well represented by this config (though, as mentioned in the README, all of those runs use batch size 2, sequence length 2048 with a packed dataset, and no gradient accumulation). Looking at the config, you can see a couple of different things:

1) We use bitsandbytes's PagedAdamW8bit optimizer to store optimizer states in reduced precision. The "paged" part matters if you're running out of GPU memory: paged optimizers can move optimizer state out to CPU memory when the GPU runs short. I'd recommend reading Tim's great comment for more details on this.

2) We fuse the optimizer step with the backward pass: as soon as a parameter's gradient is ready, we update that parameter and discard its gradient.
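The fused optimizer step in item 2 can be illustrated with a stdlib-only toy. This is a minimal sketch under stated assumptions, not torchtune's code: parameters are scalars, the optimizer is plain SGD instead of PagedAdamW8bit, and `Param`, `backward`, `standard_step`, and `fused_step` are hypothetical names. The point it shows is that a standard step keeps every gradient alive until `optimizer.step()`, while the fused version updates and frees each gradient the moment it is produced, so at most one gradient exists at a time. In real PyTorch code, a per-parameter hook such as `torch.Tensor.register_post_accumulate_grad_hook` (PyTorch 2.1+) is one way to run the step at that moment.

```python
# Toy model of fusing the optimizer step into backward (hypothetical names).
# Assumptions: scalar parameters and plain SGD; the real setup uses AdamW-8bit.
LR = 0.1  # hypothetical learning rate


class Param:
    def __init__(self, value: float) -> None:
        self.value = value
        self.grad = None  # set by backward(), cleared after the update
        self.post_grad_hooks = []  # callables run once this param's grad is ready


def backward(params, grads):
    """Toy autograd: gradients become ready one parameter at a time, last layer first.

    Returns the peak number of gradients held in memory at the same time.
    """
    peak = 0
    for p, g in zip(reversed(params), reversed(grads)):
        p.grad = g
        peak = max(peak, sum(q.grad is not None for q in params))
        for hook in p.post_grad_hooks:
            hook(p)
    return peak


def sgd_update(p):
    p.value -= LR * p.grad


def standard_step(values, grads):
    params = [Param(v) for v in values]
    peak = backward(params, grads)  # every gradient stays alive ...
    for p in params:  # ... until the separate optimizer step
        sgd_update(p)
        p.grad = None
    return [p.value for p in params], peak


def fused_step(values, grads):
    params = [Param(v) for v in values]

    def step_and_free(p):
        sgd_update(p)  # update this parameter as soon as its gradient exists
        p.grad = None  # and drop the gradient immediately

    for p in params:
        p.post_grad_hooks.append(step_and_free)
    peak = backward(params, grads)
    return [p.value for p in params], peak


if __name__ == "__main__":
    values, grads = [1.0, 2.0, 3.0, 4.0], [0.5, -1.0, 2.0, 0.0]
    print(standard_step(values, grads))  # same updated values, peak of 4 live grads
    print(fused_step(values, grads))     # same updated values, peak of 1 live grad
```

This works because each parameter's update depends only on its own gradient (and, for AdamW, its own optimizer state). Anything that needs all gradients at once, such as clipping by the global gradient norm, does not fit this pattern as directly.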

We also use full bf16 training and activation checkpointing (we checkpoint every transformer layer; you can also checkpoint less frequently to train faster at the cost of more memory). But these are pretty standard techniques anyway.
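The checkpointing trade-off above can be put in rough numbers with a back-of-envelope cost model. This sketch assumes a simplified picture: a checkpointed layer saves only its input and re-runs its forward pass during backward, while a non-checkpointed layer saves all of its intermediate activations. The function name `ac_cost` and the per-layer sizes used below are hypothetical and illustrative, not measurements of any real model.

```python
import math


def ac_cost(num_layers: int, every: int, full_act_gib: float, input_act_gib: float):
    """Rough cost model for checkpointing every `every`-th transformer layer.

    Assumption (simplified): a checkpointed layer keeps only its input
    (input_act_gib) and recomputes its forward pass in backward; any other
    layer keeps all of its intermediate activations (full_act_gib).
    Returns (activation memory in GiB, number of layers recomputed).
    """
    checkpointed = math.ceil(num_layers / every)  # layers 0, every, 2*every, ...
    plain = num_layers - checkpointed
    memory = checkpointed * input_act_gib + plain * full_act_gib
    return memory, checkpointed


if __name__ == "__main__":
    # Hypothetical per-layer sizes, chosen only to show the trend.
    for every in (1, 2, 4):
        mem, recomputed = ac_cost(28, every, full_act_gib=1.0, input_act_gib=0.05)
        print(f"every={every}: ~{mem:.2f} GiB activations, {recomputed} layers recomputed")
```

As `every` grows, fewer layers are recomputed (faster steps) but more layers keep their full activations (more memory), which is the trade-off described above.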

There are a couple of other things we do to get those numbers that aren't explicitly represented in the default config:

3) For models with large vocabularies, there can be a big memory spike at the output projection and cross-entropy loss calculation. We chunk this computation by default, which saves several GiB (see this PR).

4) We use torch.compile (set compile=True). This fuses various operations and can save several GiB compared with running without compile (it also makes training faster). We compile each transformer layer separately, which gives faster compile times, and we also compile the cross-entropy loss.
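For a sense of scale on item 3: with batch size 2 and sequence length 2048 (4096 tokens) and, assuming a Llama-3-style vocabulary of 128,256 tokens, a single fp32 logits tensor is 4096 × 128,256 × 4 bytes ≈ 1.96 GiB, before counting softmax intermediates or gradients. The sketch below shows the chunking idea in plain Python; it is not torchtune's implementation (that is in the PR above). The helper names are hypothetical, and it is forward-only: it tracks the peak size of the logits buffer, not what autograd would save for the backward pass.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def project(hidden: Sequence[Sequence[float]], weight: Matrix) -> Matrix:
    """Output projection: logits[t][v] = dot(hidden[t], weight[v])."""
    return [[sum(h * w for h, w in zip(row, w_v)) for w_v in weight] for row in hidden]


def token_ce(logits_row: Sequence[float], target: int) -> float:
    """Cross-entropy for one token, using a numerically stable log-sum-exp."""
    m = max(logits_row)
    lse = m + math.log(sum(math.exp(x - m) for x in logits_row))
    return lse - logits_row[target]


def ce_full(hidden, weight, targets):
    """Unchunked: materializes all n_tokens x vocab logits at once."""
    logits = project(hidden, weight)
    loss = sum(token_ce(r, t) for r, t in zip(logits, targets)) / len(targets)
    return loss, len(logits) * len(weight)  # (mean loss, peak logits elements)


def ce_chunked(hidden, weight, targets, num_chunks):
    """Chunked: project and score one slice of tokens at a time."""
    n = len(targets)
    size = math.ceil(n / num_chunks)
    total, peak = 0.0, 0
    for start in range(0, n, size):
        # Only a (chunk x vocab) slice of logits exists at any one time.
        logits = project(hidden[start:start + size], weight)
        peak = max(peak, len(logits) * len(weight))
        total += sum(token_ce(r, t) for r, t in zip(logits, targets[start:start + size]))
    return total / n, peak  # mean over all tokens, same as the unchunked loss


if __name__ == "__main__":
    hidden = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]  # 4 tokens, dim 2
    weight = [[0.2, 0.1], [-0.3, 0.4], [0.0, 1.0]]              # vocab 3
    targets = [0, 2, 1, 0]
    print(ce_full(hidden, weight, targets))
    print(ce_chunked(hidden, weight, targets, num_chunks=2))
```

The loss matches the unchunked version up to floating-point summation order, while the largest logits buffer shrinks by roughly the number of chunks.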

I realize I didn't explicitly respond to any of your questions, but I hope this answers some of them. I'll also respond to a couple of specific points:

How does torchtune handle gradient accumulation or optimization differently?

I discussed optimization above; we don't really do anything different with respect to gradient accumulation.

Are there any trade-offs made to achieve this memory efficiency?

There are always trade-offs 😃. Most memory optimizations come at the cost of slower training or reduced model quality, though some (like torch.compile) don't really have such trade-offs. We don't enable features that negatively impact model quality, but if you have more memory available, you can disable some of these memory-saving techniques to train faster.

In that vein, I'd refer you to our memory optimization tutorial page. It covers many of the memory-saving techniques I've described, along with some I haven't. Feel free to let me know if you have any follow-up questions.