sail-sg / zero-bubble-pipeline-parallelism

Zero Bubble Pipeline Parallelism

[QUESTION] The timing for B and W appears to be incorrect #22

Open RookieHong opened 1 month ago

RookieHong commented 1 month ago

Your question

It seems that B's timing includes W, while W accounts only for the time of gradient accumulation.

In the megatron/core/pipeline_parallel/zb_schedules.py file, the schedule_b function times the following call:

input_tensor_grad = backward_step(
    input_tensor, output_tensor, output_tensor_grad, self.model_type,
    self.config
)

This is actually B+W: it computes the gradients with respect to both the inputs and the weights.
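To make the B/W distinction concrete, consider a single linear layer Y = X @ W: its backward pass consists of two independent GEMMs, one for the input gradient (B) and one for the weight gradient (W). The following is a minimal plain-Python sketch for illustration only; the helper names are hypothetical and not taken from the repository.

```python
# Backward pass of one linear layer Y = X @ W, split into two GEMMs.
# "B" produces dX (sent to the previous stage); "W" produces dW.
# Plain-Python matrices (lists of rows) keep the example dependency-free.

def matmul(a, b):
    """Multiply an (n x k) matrix by a (k x m) matrix."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(a):
    return [list(row) for row in zip(*a)]

def backward_input(dy, w):
    # B: gradient w.r.t. the layer input, dX = dY @ W^T
    return matmul(dy, transpose(w))

def backward_weight(x, dy):
    # W: gradient w.r.t. the weight, dW = X^T @ dY
    return matmul(transpose(x), dy)

x = [[1.0, 2.0], [3.0, 4.0]]   # 2 tokens x 2 input features
w = [[0.5, -1.0], [2.0, 0.0]]  # 2 x 2 weight
dy = [[1.0, 0.0], [0.0, 1.0]]  # upstream gradient

dx = backward_input(dy, w)   # needed right away by the previous stage
dw = backward_weight(x, dy)  # can be deferred in a zero-bubble schedule
```

Because the previous pipeline stage waits only for dx, computing dw later does not delay it; that is what makes deferring W useful.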

Meanwhile, in the schedule_w function, W times only this call:

WeightGradStore.pop(chunk=scheduled_node.chunk)

After a global search for WeightGradStore.put, I found that the only operations it stores are gradient-accumulation kernels, specifically fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32 or fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16. Therefore, W actually times only the gradient accumulation!
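The put/pop pattern can be sketched as a per-chunk queue of deferred closures. This is a hypothetical, simplified stand-in (the class name and method signatures are assumptions, not the repository's WeightGradStore); it shows that whatever work put() stores only executes, and is therefore only timed, inside pop().

```python
from collections import defaultdict, deque

class DeferredWeightGradStore:
    """Hypothetical stand-in for a put/pop weight-gradient store."""

    def __init__(self):
        # One FIFO of pending closures per model chunk.
        self._queues = defaultdict(deque)

    def put(self, chunk, fn):
        # Called during B: record the weight-gradient work without running it.
        self._queues[chunk].append(fn)

    def pop(self, chunk):
        # Called during W: run every pending closure for this chunk, in order.
        queue = self._queues[chunk]
        ran = 0
        while queue:
            queue.popleft()()
            ran += 1
        return ran

log = []
store = DeferredWeightGradStore()
# "B" phase: only closures are stored, nothing executes yet.
store.put(0, lambda: log.append("wgrad layer 0"))
store.put(0, lambda: log.append("wgrad layer 1"))
ops_before_pop = len(log)
# "W" phase: the stored work actually runs here.
ops_run = store.pop(chunk=0)
```

So if the stored closures contain the full weight-gradient kernel, the W timer should cover that kernel rather than the B timer.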

The timer statistics also support this point: W's duration is very short, while B takes almost twice as long as F:

[Screenshot: timer statistics]
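As a rough sanity check on these timings: for a linear layer, F, B and W are each one GEMM of the same size, so a working split should give F ≈ B ≈ W, with B + W ≈ 2F. A back-of-the-envelope FLOP count, assuming only the linear-layer GEMMs matter (attention, elementwise ops and communication are ignored, and the layer sizes are arbitrary):

```python
# FLOP count for one linear layer Y = X @ W with X: (tokens x d_in) and
# W: (d_in x d_out). A GEMM of (m x k) @ (k x n) costs 2*m*k*n FLOPs.
# Assumption: attention, activations and communication are ignored.

def gemm_flops(m, k, n):
    return 2 * m * k * n

def linear_layer_flops(tokens, d_in, d_out):
    f = gemm_flops(tokens, d_in, d_out)  # F: Y  = X @ W
    b = gemm_flops(tokens, d_out, d_in)  # B: dX = dY @ W^T
    w = gemm_flops(d_in, tokens, d_out)  # W: dW = X^T @ dY
    return f, b, w

f, b, w = linear_layer_flops(tokens=4096, d_in=1024, d_out=4096)
print((b + w) / f)  # -> 2.0
```

Under these assumptions, B ≈ 2F together with W ≈ 0 is the pattern one would expect if the weight-gradient GEMM were running inside B instead of being deferred to W.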

Is this the expected result?

ufotalent commented 1 month ago

Hi, thanks for the interest in our work and implementation.

wgrad_gemm_accum_fp16 actually does both the weight-gradient calculation and the accumulation. It is a fused kernel that computes the weight gradient and accumulates it, which is why it is called gemm + accum.
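The point about the fused kernel is that the deferred closure contains the whole weight-gradient GEMM, not just an addition. Its semantics can be sketched in plain Python as follows; the function name is hypothetical, and the real kernel is a CUDA extension operating on GPU tensors.

```python
# Semantic sketch of a fused "GEMM + accumulate" weight-gradient kernel:
# main_grad += X^T @ dY in a single pass (hypothetical name, plain Python).

def wgrad_gemm_accum(x, dy, main_grad):
    """Accumulate X^T @ dY into main_grad in place.

    x: tokens x d_in, dy: tokens x d_out, main_grad: d_in x d_out.
    """
    tokens = len(x)
    for i in range(len(main_grad)):
        for j in range(len(main_grad[0])):
            acc = 0.0
            for t in range(tokens):
                acc += x[t][i] * dy[t][j]  # GEMM part: (X^T @ dY)[i][j]
            main_grad[i][j] += acc         # accumulation part
    return main_grad

main_grad = [[0.0, 0.0], [0.0, 0.0]]
x = [[1.0, 2.0], [3.0, 4.0]]
dy = [[1.0, 0.0], [0.0, 1.0]]
# Two micro-batches accumulate into the same gradient buffer.
wgrad_gemm_accum(x, dy, main_grad)
wgrad_gemm_accum(x, dy, main_grad)
```

Fusing the two steps avoids materializing dW as a separate tensor before adding it to the accumulated gradient.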

Just to make sure, is gradient_accumulation_fusion enabled in your setup? Our implementation of the B-W split only works when gradient_accumulation_fusion is enabled.
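One plausible way the B-W split can depend on gradient_accumulation_fusion is sketched below: with fusion enabled, the weight-gradient work is packaged as a closure and deferred to W; otherwise it is assumed to run inside B. The function names and this fallback behaviour are assumptions for illustration, not the repository's actual code paths.

```python
from collections import deque

def run_b(wgrad_fn, fusion_enabled, pending):
    """Simulate the B phase for one layer and return the ops B executed."""
    executed = ["dgrad"]          # B always computes the input gradient
    if fusion_enabled:
        pending.append(wgrad_fn)  # split: defer the weight-gradient closure to W
    else:
        wgrad_fn()                # assumed fallback: weight gradient runs inside B
        executed.append("wgrad")
    return executed

def run_w(pending):
    """Simulate the W phase: run every closure B deferred; return the count."""
    count = 0
    while pending:
        pending.popleft()()
        count += 1
    return count

results = {}
for fusion in (True, False):
    pending = deque()
    b_ops = run_b(lambda: None, fusion, pending)
    w_ops = run_w(pending)
    results[fusion] = (b_ops, w_ops)
    print(f"fusion={fusion}: B ran {b_ops}, W ran {w_ops} deferred op(s)")
```

Under this assumption, if the split were silently not taking effect, B would absorb the weight-gradient cost and W would have almost nothing left to time.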

RookieHong commented 1 month ago

Thanks for the reply!

I did not pass the --no-gradient-accumulation-fusion flag, and get_args().gradient_accumulation_fusion is True at runtime. Even so, W's runtime is very short, while B takes almost twice as long as F. Could there be another reason for this?