Closed fanshiqing closed 5 months ago
Hi! This is a great question! I've been meaning to write this down because it keeps coming up and the explanation is non-trivial. Let me give it a go here...
(1) how to understand this behavior in a natural way?
If you're running with `tokens` tokens per device and `experts` experts on `N` devices with pure data parallelism (no expert model parallelism), each device calculates the average gradient for its tokens and then averages over the `N` devices, so your gradient is effectively scaled by `1 / (tokens * N)` for the expert weights.
If you're instead running with `N`-way expert model parallelism, there is no final gradient all-reduce for the expert weights, so your gradient is scaled by only `1 / tokens`. We scale by `1 / expert_parallel_world_size = 1 / N` to correct this so that the two settings match.
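The equivalence between the two settings can be checked numerically. This is a minimal sketch with made-up sizes, not MegaBlocks code; `per_token_grads` stands in for the per-token gradients of one expert weight:

```python
import numpy as np

# Assumed setup: N "devices", each holding `tokens` per-token gradients
# for the same scalar expert weight. Names are illustrative.
N, tokens = 4, 8
rng = np.random.default_rng(0)
per_token_grads = rng.normal(size=(N, tokens))  # one row per device

# Pure data parallelism: each device averages over its tokens, then the
# all-reduce averages over the N devices.
# Net scale on the summed gradient: 1 / (tokens * N).
dp_grad = per_token_grads.mean(axis=1).mean()

# N-way expert parallelism: the expert's owner sees all N * tokens token
# gradients but only picks up the 1 / tokens factor from the per-token
# loss normalizer; there is no all-reduce over the expert weights.
ep_grad_unscaled = per_token_grads.sum() / tokens

# Applying the 1 / expert_parallel_world_size correction recovers the
# data-parallel semantics.
ep_grad = ep_grad_unscaled / N

assert np.isclose(dp_grad, ep_grad)
```

Both paths end up applying the same net `1 / (tokens * N)` scale, which is exactly what the correction is for.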
(2) if expert parallelism (sharding on the num_experts dim) and tensor parallelism (sharding on the ffn_hidden_size dim) are both enabled, then each expert gradient will actually be scaled by `1 / (ep_size * dp_size)`, which seems a bit strange compared to the traditional gradient scaling for DDP.
I think the above explanation covers this as well? We need the additional scale factor to match data parallel semantics. Lmk if this makes sense! I know it is subtle and I want to make sure we can explain it (and that it is correct, obviously:))!
Thanks for the detailed clarification! I'm still a bit confused, though. Here is my understanding.
Gradient scaling works differently under DP than under TP:
DP: In DP, each rank has a complete copy of the model and computes gradients based on a different subset of the data. The gradients are then averaged across all DP ranks. If you have 4 DP ranks, you would typically scale the gradients by 1/4.
TP: In TP, the model is split across different ranks. Each rank computes a portion of the forward and backward pass. The gradients in TP are typically not scaled by the number of TP ranks because the model is divided among these ranks, and they collectively compute the full gradient. No scaling is needed for this case.
For the same model trained with 8 ranks, suppose there are four parallel mapping configs: (1) 8 DP ranks; (2) 4 DP ranks and 2 TP ranks; (3) 2 DP ranks and 4 TP ranks; (4) 8 TP ranks.
Let's break down each of these configurations and check the corresponding scaling factor:
(1) 8 DP Ranks: gradients are averaged across all 8 data-parallel replicas, so the scaling factor is 1/8.
(2) 4 DP Ranks and 2 TP Ranks: only the 4 data-parallel replicas average gradients (each pair of TP ranks jointly holds one model copy), so the scaling factor is 1/4.
(3) 2 DP Ranks and 4 TP Ranks: by the same reasoning, the scaling factor is 1/2.
(4) 8 TP Ranks:
Explanation: In this configuration, all ranks are used for TP, and there's no DP involved. TP ranks collectively compute the entire model's gradients, so there's no need to scale the gradients down as each part of the model's gradients is computed by a different rank.
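In other words, the scaling factor depends only on the data-parallel degree. A tiny sketch of that rule for the 8-rank configs above (the helper name is illustrative, not a real API):

```python
# Gradients are averaged only across data-parallel replicas, so the
# factor is 1 / dp_size regardless of how many TP ranks are used.
def ddp_grad_scale(dp_size: int, tp_size: int) -> float:
    assert dp_size * tp_size == 8  # the 8-rank configs discussed here
    return 1.0 / dp_size

# The four configurations: (dp_size, tp_size) -> scaling factor.
assert ddp_grad_scale(8, 1) == 1 / 8
assert ddp_grad_scale(4, 2) == 1 / 4
assert ddp_grad_scale(2, 4) == 1 / 2
assert ddp_grad_scale(1, 8) == 1.0  # pure TP: no scaling
```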
Is anything wrong in my reasoning so far?
Oh, now I understand this issue. Thanks again!
No problem! One key bit of information I left out of my original post is that this scaling essentially replaces the scaling factor applied during the all-reduce + scale that computes the average gradients across ranks under data parallelism.
Each rank gets a factor of `1 / tokens` from the loss function (e.g., average cross-entropy loss per token), and then a factor of `1 / N` is applied during the DDP backward pass before the all-reduce. With EP we get the former but not the latter, which is why we apply this correction.
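One way to picture where that correction could hook in (this is a hedged PyTorch sketch with illustrative names, not MegaBlocks' actual implementation) is a gradient hook on each expert parameter that supplies the missing `1 / N` factor:

```python
import torch

# Hypothetical helper: attach a hook that divides each expert
# parameter's gradient by the expert-parallel world size, supplying
# the 1 / N factor that DDP's all-reduce would otherwise provide.
def scale_expert_grads(expert_params, ep_world_size: int):
    for p in expert_params:
        p.register_hook(lambda grad: grad / ep_world_size)

# Toy check: the raw gradient of this loss w.r.t. w is [0, 1, 2, 3];
# with ep_world_size=4 the hook scales it to [0, 0.25, 0.5, 0.75].
w = torch.ones(4, requires_grad=True)
scale_expert_grads([w], ep_world_size=4)
loss = (w * torch.arange(4.0)).sum()
loss.backward()
assert torch.allclose(w.grad, torch.tensor([0.0, 0.25, 0.5, 0.75]))
```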
On the topic of expert weight gradients, I should add that I've seen at least one paper adjust the expert weight gradients by an additional scale factor to account for the effectively lower batch size seen by each expert. I've been wanting to try this; to do so, we would simply adjust the scale factor we're already using to preserve semantics so that it additionally applies their correction.
Hi @tgale96, one naive question on the gradient scale for expert weights: in the current implementation, we scale the MoE weight gradients by `1 / expert_parallel_world_size` (src). My questions are: (1) how to understand this behavior in a natural way? (2) if expert parallelism (sharding on the num_experts dim) and tensor parallelism (sharding on the ffn_hidden_size dim) are both enabled, then each expert gradient will actually be scaled by `1 / (ep_size * dp_size)`, which seems a bit strange compared to the traditional gradient scaling for DDP.