meta-llama / llama-models

Utilities intended for use with Llama models.

A doubt about training infrastructure: why can more micro-batches hide P2P communication? #78

Open eileenzhujuan opened 4 months ago

eileenzhujuan commented 4 months ago

Thank you for the detailed Llama 3.1 report, which is very inspirational. I read the report and have a doubt about the training infrastructure. In Section 3.3.2, titled "Parallelism for Model Scaling", there is a sentence that says they use "more micro-batches to hide point-to-point communication". I cannot figure out why more micro-batches can hide P2P communication. Can somebody help me and explain this?

[screenshot of the relevant passage from Section 3.3.2]

AnimeshBote commented 3 months ago

In the paper, they mention that they use asynchronous point-to-point communication in PP, so by increasing the number of micro-batches they can overlap the time spent in point-to-point communication with the processing of another micro-batch. So basically the GPU doesn't waste time waiting while the network call is being made to distribute the gradient across nodes. I think this is the meaning they are trying to convey.
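(A minimal sketch of the idea, assuming PyTorch's `torch.distributed` async P2P API and an already-initialized process group; `run_stage`, `next_rank`, and the surrounding setup are hypothetical illustrations, not the actual Llama 3 training code.)

```python
# Illustrative sketch only: one pipeline stage issues an async send of each
# micro-batch's output and immediately starts computing the next micro-batch,
# so the P2P transfer is hidden behind compute instead of blocking the GPU.
import torch
import torch.distributed as dist

def run_stage(model, micro_batches, next_rank):
    pending = []                                # in-flight isend handles
    for mb in micro_batches:
        out = model(mb)                         # compute the current micro-batch
        work = dist.isend(out, dst=next_rank)   # async P2P send, returns immediately
        pending.append((work, out))             # keep the tensor alive until the send finishes
        # the loop continues, so the next micro-batch's compute overlaps the transfer
    for work, _ in pending:
        work.wait()                             # only the tail-end sends are ever exposed
```

With a single micro-batch there is nothing left to compute while the send is in flight, so the transfer cost is fully exposed; with many micro-batches, roughly only the last send (and the pipeline ramp-up) remains uncovered.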

eileenzhujuan commented 3 months ago

@AnimeshBote Thank you for the reply. I think I understand now: with more micro-batches executed back to back, a larger fraction of the P2P communication can be overlapped with compute. However, I didn't get the second sentence, "the GPU doesn't waste time waiting while the network call is being made to distribute the gradient across nodes." What is the relationship with gradient distribution? And by "gradient distribution", do you mean the gradient averaging step introduced by data parallelism?
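(To make the ratio argument concrete, here is a toy timing model; `t_compute` and `t_comm` are made-up numbers, not figures from the paper. With async sends, roughly only one transfer per stage stays uncovered, so the exposed P2P fraction shrinks as the micro-batch count grows.)

```python
# Toy timing model (illustrative only; the numbers are assumptions, not Llama 3 figures).
# One pipeline stage processes m micro-batches; each micro-batch needs t_compute of GPU
# work and its output needs t_comm of P2P transfer. With async sends, the transfer of
# micro-batch i overlaps the compute of micro-batch i+1, so only the last transfer
# (plus any part of t_comm that exceeds t_compute) stays exposed.
def exposed_comm_fraction(m, t_compute=10.0, t_comm=2.0):
    overlapped = min(t_comm, t_compute) * (m - 1)   # hidden behind later compute
    exposed = m * t_comm - overlapped               # what the GPU still waits for
    total = m * t_compute + exposed
    return exposed / total

for m in (1, 2, 4, 8, 16):
    print(f"micro-batches={m:2d}  exposed P2P fraction={exposed_comm_fraction(m):.3f}")
```

Running this prints an exposed fraction of about 17% for one micro-batch, falling to roughly 1% at 16 micro-batches, which is the sense in which more micro-batches "hide" the point-to-point communication.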