sail-sg / zero-bubble-pipeline-parallelism

Zero Bubble Pipeline Parallelism

interleaved 1F1B seems to work better #21

Open zhj96 opened 6 months ago

zhj96 commented 6 months ago

I ran multiple sets of experiments and found that ZB is better than 1F1B. However, interleaved 1F1B seems to be slightly faster than ZB_V and only slightly slower than ZB_2P, while using much less GPU memory than ZB_2P.

Machine: 8× H800 80GB. Model: 6.2B parameters.

| Schedule | Throughput (samples/s on 8 GPUs) | GPU memory |
|---|---|---|
| 1F1B | 55 | 48 GB |
| INTERLEAVED_1F1B | 66 | 57 GB |
| ZB_2P | 67 | 79 GB |
| ZB_V | 64 | 53 GB |
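One quick way to read these numbers is relative to the 1F1B baseline. The sketch below only does arithmetic on the figures reported above; it assumes the memory column is in GB, and the helper name `relative_to_baseline` is hypothetical.

```python
# Figures as reported in this issue: (throughput in samples/s on 8 GPUs, GPU memory in GB).
results = {
    "1F1B": (55, 48),
    "INTERLEAVED_1F1B": (66, 57),
    "ZB_2P": (67, 79),
    "ZB_V": (64, 53),
}


def relative_to_baseline(name, baseline="1F1B"):
    """Return (speedup over the baseline, extra memory in GB over the baseline)."""
    tput, mem = results[name]
    base_tput, base_mem = results[baseline]
    return tput / base_tput, mem - base_mem


for name in results:
    speedup, extra_mem = relative_to_baseline(name)
    print(f"{name:<17} speedup x{speedup:.3f}  extra memory {extra_mem:+d} GB")
```

With these figures, interleaved 1F1B is 20% faster than 1F1B for 9 GB more memory, ZB_2P is about 22% faster for 31 GB more, and ZB_V is about 16% faster for 5 GB more.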

ufotalent commented 6 months ago

Hi, can you share your pipeline-parallelism settings, such as the PP degree, the number of microbatches (global batch size / micro batch size / DP), and the number of virtual stages used for interleaved_1f1b?

One possible reason is that the number of microbatches, or the number of virtual stages in interleaved 1F1B, is already very high, so interleaved 1F1B already achieves good efficiency.
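Why a large number of microbatches or virtual stages helps can be seen from the idealized bubble estimate in the Megatron-LM interleaved-schedule paper (Narayanan et al., 2021): bubble time relative to ideal compute time is (p - 1) / m for 1F1B and (p - 1) / (v · m) for interleaved 1F1B, where p is the PP degree, m the number of microbatches, and v the number of virtual stages per GPU. The sketch assumes equal stage times and free communication; the function name and the p/v values are hypothetical, not the settings from this issue.

```python
from fractions import Fraction


def bubble_fraction(p: int, m: int, v: int = 1) -> Fraction:
    """Idealized pipeline bubble time relative to ideal compute time.

    p: pipeline-parallel degree (number of stages)
    m: number of microbatches per iteration
    v: virtual stages per GPU (v=1 is plain 1F1B, v>1 is interleaved 1F1B)

    Assumes equal per-stage compute time and free communication.
    """
    if p < 1 or m < 1 or v < 1:
        raise ValueError("p, m and v must all be positive")
    return Fraction(p - 1, v * m)


# Hypothetical p=8 and v=4; these are NOT the settings used in this issue.
for m in (8, 32, 128, 512):
    plain = bubble_fraction(8, m)
    interleaved = bubble_fraction(8, m, v=4)
    print(f"m={m:>3}  1F1B {float(plain):7.2%}  interleaved v=4 {float(interleaved):7.2%}")
```

Under these assumptions, a hypothetical 8-stage pipeline has an 87.5% bubble (7/8) at m=8 but only about 1.4% (7/512) at m=512, which leaves little room for any schedule to improve on 1F1B.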

zhj96 commented 6 months ago

> Hi, can you share your pipeline-parallelism settings, such as the PP degree, the number of microbatches (global batch size / micro batch size / DP), and the number of virtual stages used for interleaved_1f1b?
>
> One possible reason is that the number of microbatches, or the number of virtual stages in interleaved 1F1B, is already very high, so interleaved 1F1B already achieves good efficiency.


Yes. I think Zero Bubble needs to be compared against interleaved 1F1B at its best to show its value, so I set the number of virtual stages to the maximum value.

When the number of microbatches is small, both ZB and interleaved_1f1b outperform 1F1B. When the number of microbatches is large enough (e.g. 512), the difference among the three is almost negligible (less than 1%). In actual pre-training, ZeroBubble gives very little improvement over 1F1B.

Did I make a mistake somewhere?

ufotalent commented 6 months ago


> Did I make a mistake somewhere?

How much the pipeline bubble matters depends on your settings, specifically on the number of microbatches divided by the number of pipeline stages. In hybrid parallelism we have pp = GPUs / dp / tp and number of microbatches = global bs / dp / microbs, so number of microbatches / pp = global bs * tp / GPUs / microbs. It is too strong to say the pipeline bubble is negligible; it depends on your settings. In many practical LLM pretraining cases, the number of microbatches is capped by the global batch size, which is in turn capped by the optimizer algorithm. Usually the bubble problem gets worse when you train a bigger model on more GPUs, because with the global batch size capped, adding GPUs shrinks the number of microbatches per pipeline stage.
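The arithmetic in the comment above can be checked with a small sketch. It assumes a standard TP × PP × DP layout where every dimension divides evenly; the function name and the two example layouts are hypothetical, chosen only to show how the ratio behaves when the global batch is held fixed.

```python
from fractions import Fraction


def microbatches_per_stage(gpus: int, tp: int, dp: int, global_bs: int, micro_bs: int):
    """Return (pp, number of microbatches, microbatches / pp) for a hybrid TP x PP x DP layout."""
    if gpus % (dp * tp):
        raise ValueError("gpus must be divisible by dp * tp")
    if global_bs % (dp * micro_bs):
        raise ValueError("global_bs must be divisible by dp * micro_bs")
    pp = gpus // (dp * tp)            # pp = GPUs / dp / tp
    m = global_bs // (dp * micro_bs)  # microbatches = global bs / dp / microbs
    ratio = Fraction(m, pp)
    # Sanity check of the closed form: global bs * tp / GPUs / microbs (dp cancels out).
    assert ratio == Fraction(global_bs * tp, gpus * micro_bs)
    return pp, m, ratio


# Hypothetical layouts: the same global batch of 256 sequences on 64 vs. 512 GPUs.
print(microbatches_per_stage(gpus=64, tp=8, dp=1, global_bs=256, micro_bs=1))
print(microbatches_per_stage(gpus=512, tp=8, dp=8, global_bs=256, micro_bs=1))
```

Keeping the global batch fixed and scaling from 64 to 512 GPUs drops the ratio from 32 to 4 microbatches per stage, which is the regime where the bubble becomes significant.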

Interleaved 1F1B is a good strategy. That said, zero bubble is actually orthogonal to interleaving: we can easily produce a schedule that is both interleaved and zero bubble. One example is something like the one below, which combines ZB-H1 with interleaving. We leave this for the community to explore, because the main point of zero bubble is the idea of the B-W split (separating the backward pass into B, the gradient with respect to the input, and W, the gradient with respect to the weights) and its ability to eliminate bubbles, not a claim that it is better than everything else.

[Image: ZB-H1 + interleave schedule]
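To make the B-W split more concrete, here is a toy sketch for a single bias-free linear layer in plain Python. It is an illustration only, not code from this repository; the function names and the `pending_w` queue are hypothetical. B (the input gradient) is on the critical path because the previous pipeline stage is waiting for it, while W (the weight gradient) can be queued and run later, which is what gives a zero-bubble scheduler the freedom to fill otherwise idle slots.

```python
from collections import deque


def matmul(a, b):
    """Plain-Python matrix product of a (n x k) and b (k x m)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


def forward(x, w):
    """F: y = x @ w for a bias-free linear layer."""
    return matmul(x, w)


def backward_b(grad_y, w):
    """B: gradient w.r.t. the layer input; the previous stage is waiting on it."""
    return matmul(grad_y, transpose(w))


def backward_w(x, grad_y):
    """W: gradient w.r.t. the weight; only needed before the optimizer step."""
    return matmul(transpose(x), grad_y)


x = [[1.0, 2.0]]
w = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
y = forward(x, w)
grad_y = [[1.0, 0.0, 1.0]]      # pretend this gradient arrived from the next stage

pending_w = deque()
grad_x = backward_b(grad_y, w)  # run B immediately and send grad_x upstream
pending_w.append((x, grad_y))   # defer W; a scheduler could run it in an idle slot
# ... F/B work for later microbatches would run here ...
grad_w = backward_w(*pending_w.popleft())
print(grad_x, grad_w)
```

The deferred W produces the same weight gradient it would have produced if run immediately; only its position in the schedule changes.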