pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Very High RAM Usage when resuming training from unsharded recipe_state #1318

Open bachvudinh opened 1 month ago

bachvudinh commented 1 month ago

I've encountered an issue while using torchtune for distributed training across multiple GPUs: when resuming training from an unsharded recipe_state, RAM usage becomes extremely high (over 900 GB).

Current behavior:

When resuming training, loading the unsharded recipe_state causes excessive RAM consumption, which makes it hard to resume training efficiently, especially on multi-node setups.

Temporary workaround:

I've added a time delay for each GPU based on its rank, so ranks load the checkpoint one after another, and each rank then waits for the others to finish before training resumes (see the sketch below). While this mitigates the immediate problem, it's not an optimal long-term solution.
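For illustration, here is a minimal sketch of that kind of rank-staggered load; the function name, the delay value, and the use of `torch.load` with `mmap` are my own assumptions, not the actual patch:

```python
import time

import torch
import torch.distributed as dist


def load_recipe_state_staggered(ckpt_path: str, delay_per_rank: float = 30.0):
    """Stagger the unsharded checkpoint load by rank so that not every
    process holds the full state dict in host RAM at the same moment."""
    rank = dist.get_rank()

    # Each rank waits rank * delay_per_rank seconds before loading, so peak
    # host memory is dominated by only one (or a few) ranks at a time.
    time.sleep(rank * delay_per_rank)

    # map_location="cpu" keeps the load off the GPU; mmap=True (PyTorch >= 2.1)
    # lets ranks share pages of the same file instead of each copying it.
    state = torch.load(ckpt_path, map_location="cpu", mmap=True)

    # All ranks wait here until everyone has finished loading.
    dist.barrier()
    return state
```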

So I propose:

joecummings commented 4 weeks ago

I think adding support for the Distributed Checkpointer (DCP) would address this - can someone from the DCP team confirm? @LucasLLC @pbontrager

DCP is definitely on our radar as something we want to integrate.
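For reference, a rough sketch of what resuming via `torch.distributed.checkpoint` could look like; the directory path and helper names are placeholders, and a real integration would go through torchtune's checkpointer (and, under FSDP, the `get_state_dict`/`set_state_dict` helpers) rather than plain `state_dict()` calls:

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter

CKPT_DIR = "/tmp/recipe_state_dcp"  # placeholder path


def save_sharded(model, optimizer):
    # Each rank writes only the shards it owns, so no rank ever has to
    # materialize the full unsharded state in host RAM.
    state_dict = {"model": model.state_dict(), "optim": optimizer.state_dict()}
    dcp.save(state_dict, storage_writer=FileSystemWriter(CKPT_DIR))


def load_sharded(model, optimizer):
    # The pre-populated state_dict tells DCP which shards this rank needs;
    # tensors are loaded in place, shard by shard.
    state_dict = {"model": model.state_dict(), "optim": optimizer.state_dict()}
    dcp.load(state_dict, storage_reader=FileSystemReader(CKPT_DIR))
    model.load_state_dict(state_dict["model"])
    optimizer.load_state_dict(state_dict["optim"])
```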

As a side note, I will admit our testing of multi-node setups is very limited so far; we've mainly stuck to single-node training. It might be worth checking out how torchtitan handles checkpointing, as its charter explicitly covers massively parallel training.