pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning

Trying to understand batch size with FSDP #1136

Closed ScottHoang closed 1 day ago

ScottHoang commented 3 days ago

From the comments in Lora_distributed_finetuning, batch size = num_accumulated_gradients * batch_size * nproc_per_node... But shouldn't it be just batch_size * num_accumulated_gradients?

RdoubleA commented 3 days ago

Hi @ScottHoang, good question. Since FSDP also includes data parallelism, the "global batch" is sharded across GPUs/processes so each device sees a different sliver of the batch. When you specify batch_size, you are specifying the local (per-device) batch size. So to account for data parallelism AND gradient accumulation, the true global batch size = num_accumulated_gradients * local_batch_size * nproc_per_node.

Note that this only applies to distributed recipes. For single device recipes, nproc_per_node = 1 and the global batch size is simply batch_size * num_accumulated_gradients.
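For concreteness, the arithmetic looks like this (a minimal sketch; the variable names follow the discussion in this thread and are not necessarily the exact config keys):

```python
# Minimal sketch of the batch size arithmetic discussed above.
# Variable names are illustrative, not literal torchtune config keys.

batch_size = 2                 # local (per-device) batch size
num_accumulated_gradients = 4  # gradient accumulation steps
nproc_per_node = 8             # number of GPUs/processes

# Distributed recipe: each process works on its own slice of the data.
global_batch_size = batch_size * num_accumulated_gradients * nproc_per_node
print(global_batch_size)  # 64

# Single device recipe: nproc_per_node is effectively 1.
print(batch_size * num_accumulated_gradients)  # 8
```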

ScottHoang commented 3 days ago

Thank you so much for your detailed response. Is the reason we shard the global batch to fully utilize all available GPUs instead of waiting for the intermediate values from another shard?

RdoubleA commented 3 days ago

Not exactly, it's more about being able to train on a larger effective batch size. So if I have 8 GPUs but each can individually only hold up to a max of BS=2, I can use data parallelism to get an effective global batch size of 16 and significantly increase throughput. If I didn't have data parallelism, then I would be stuck with an effective global batch size of 2 that is replicated across 8 GPUs, but I would also avoid the communication overhead that comes with any form of parallelism. So there are tradeoffs.
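As a rough illustration of how that sharding works in plain PyTorch (this is not torchtune's recipe code, just a sketch that assumes a process group has already been initialized, e.g. via torchrun):

```python
# Sketch: data parallelism hands each rank a disjoint shard of the dataset,
# so a per-device batch_size of 2 across 8 ranks processes 16 samples per step.
# Assumes torch.distributed is already initialized (e.g. launched via torchrun).
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(160))  # toy dataset

sampler = DistributedSampler(
    dataset,
    num_replicas=dist.get_world_size(),  # e.g. 8
    rank=dist.get_rank(),
    shuffle=True,
)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

# Each rank iterates over its own shard; collectively the 8 ranks cover the
# full dataset, giving an effective global batch size of 16 per step.
```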

> waiting for the intermediate values from another shard

This is true for any form of parallelism (data, model, etc.), where you have to keep activations, gradients, etc. in sync across devices.
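For intuition, here is a hand-rolled sketch of what keeping gradients in sync means in the data-parallel case. DDP does this for you automatically (and FSDP uses sharded variants of the same idea), overlapping the communication with compute:

```python
# Sketch only: after each backward pass, every rank averages its gradients
# with the other ranks so the replicated parameters stay identical.
import torch.distributed as dist

def sync_gradients(model):
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad.div_(world_size)
```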

ScottHoang commented 1 day ago

Thank you. This clears up a lot!