databricks / megablocks


Gradient scale size for expert gradient #86

Closed: fanshiqing closed this issue 5 months ago

fanshiqing commented 5 months ago

Hi @tgale96, one naive question on the gradient scale for the expert weights: in the current implementation, the MoE weight gradients are scaled by 1/expert_parallel_world_size (src). My questions are: (1) how should we understand this behavior in a natural way? (2) if expert parallelism (sharding on the num_expert dim) and tensor parallelism (sharding on the ffn_hidden_state dim) are both enabled, then each expert gradient will actually be scaled by 1/(ep_size*dp_size), which seems a bit strange compared to the traditional gradient scaling for DDP.

tgale96 commented 5 months ago

Hi! This is a great question! I've been meaning to write this down because it keeps coming up and the explanation is non-trivial. Let me give it a go here...

(1) how should we understand this behavior in a natural way?

If you're running with `tokens` tokens per device and `experts` experts on N devices with pure data parallelism (no expert model parallelism), each device calculates the average gradient for its tokens and then we average over the N devices, so your gradient is effectively scaled by 1 / (tokens * N) for the expert weights.

If you're instead running with N-way expert model parallelism, there is no final gradient all reduce for the expert weights so your gradient is scaled by 1 / tokens. We scale by 1 / expert_parallel_world_size = 1 / N to correct this so that the two settings match.
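Concretely, a minimal sketch of how a correction like this can be implemented, assuming a custom autograd function that rescales the gradient in the backward pass (the `ScaleGradient` / `scale_grad` names here are illustrative, not necessarily the exact code behind the linked src):

```python
import torch


class ScaleGradient(torch.autograd.Function):
    """Identity in the forward pass; multiplies the gradient by `scale` in backward."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.scale, None


def scale_grad(w, expert_parallel_world_size):
    # With N-way expert parallelism there is no gradient all-reduce (and hence
    # no 1/N averaging) for the expert weights, so we fold the missing 1/N in here.
    return ScaleGradient.apply(w, 1.0 / expert_parallel_world_size)
```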

(2) if expert parallelism (sharding on the num_expert dim) and tensor parallelism (sharding on the ffn_hidden_state dim) are both enabled, then each expert gradient will actually be scaled by 1/(ep_size*dp_size), which seems a bit strange compared to the traditional gradient scaling for DDP.

I think the above explanation covers this as well? We need the additional scale factor to match data parallel semantics. Lmk if this makes sense! I know it is subtle and I want to make sure we can explain it (and that it is correct, obviously:))!

fanshiqing commented 5 months ago

Thanks for the detailed clarification! I'm still a bit confused, though. Here is my understanding.

The gradient scaling here works differently from a scenario with only DP and TP.

For the same model trained on 8 ranks, let's say there are four parallel mapping configs: (1) 8 DP ranks; (2) 4 DP ranks and 2 TP ranks; (3) 2 DP ranks and 4 TP ranks; (4) 8 TP ranks.

Let's break down each of these configurations and check the corresponding scaling factor (the arithmetic for cases (1) through (4) is sketched below).
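Sketching the arithmetic under the assumption of a fixed global batch of $B$ tokens, a per-rank loss that averages over that rank's local tokens, and standard DDP gradient averaging, where $g_i$ is the gradient of the per-token loss for token $i$:

$$
\begin{aligned}
\text{(1) 8 DP ranks:}\quad & g = \frac{1}{8}\sum_{r=1}^{8}\frac{1}{B/8}\sum_{i\in r} g_i = \frac{1}{B}\sum_{i=1}^{B} g_i \\
\text{(2) 4 DP ranks, 2 TP ranks:}\quad & g = \frac{1}{4}\sum_{r=1}^{4}\frac{1}{B/4}\sum_{i\in r} g_i = \frac{1}{B}\sum_{i=1}^{B} g_i \\
\text{(3) 2 DP ranks, 4 TP ranks:}\quad & g = \frac{1}{2}\sum_{r=1}^{2}\frac{1}{B/2}\sum_{i\in r} g_i = \frac{1}{B}\sum_{i=1}^{B} g_i \\
\text{(4) 8 TP ranks:}\quad & g = \frac{1}{B}\sum_{i=1}^{B} g_i
\end{aligned}
$$

So every DP/TP configuration ends up with the global-mean gradient $\frac{1}{B}\sum_i g_i$, whereas N-way expert parallelism on its own only contributes the local $\frac{1}{B/N}$ factor with no averaging all-reduce over the expert weights, which is exactly the missing $1/N$.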

fanshiqing commented 5 months ago

Oh, I've come to understand this issue now, thanks again!

tgale96 commented 5 months ago

No problem! I think one of the key bits of information that I left out of my original post is that this scaling is basically replacing the scaling factor that is applied when you do the all-reduce + scale to compute the average gradients across ranks with data parallelism.

Each rank gets a factor of 1 / tokens from the loss function (e.g., average cross entropy loss per token) and then a factor of 1 / N is applied during DDP backwards before the all reduce. With EP we get the former but not the latter, which is why we apply this correction.
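Conceptually, the synchronization step looks something like the following simplified sketch (not the actual DDP or MegaBlocks internals; the helper and its arguments are illustrative):

```python
import torch.distributed as dist


def sync_gradients(dense_params, expert_params, ep_world_size):
    # Dense (data-parallel) weights: DDP averages gradients across ranks,
    # contributing the 1/N factor on top of the 1/tokens from the loss.
    for p in dense_params:
        p.grad.div_(dist.get_world_size())
        dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)

    # Expert weights: there is no all-reduce across expert-parallel ranks, so
    # the 1/N factor never appears. Apply it manually to match DP semantics.
    for p in expert_params:
        p.grad.mul_(1.0 / ep_world_size)
```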

On the topic of expert weight gradients, I should add that I've seen at least one paper adjust the expert weight gradients by an additional scale factor to account for the effectively lower batch size seen by each expert. I've been wanting to try this, and to do so we would simply adjust the scale factor we're already using to preserve semantics so that it additionally applies their correction.
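For example (purely illustrative; the `extra_expert_scale` knob and its value are hypothetical, not something prescribed by that paper or by the current code):

```python
def expert_gradient_scale(ep_world_size, extra_expert_scale=1.0):
    # 1/ep_world_size preserves data-parallel semantics (see above);
    # extra_expert_scale is a hypothetical additional correction for the
    # smaller effective batch each expert sees.
    return (1.0 / ep_world_size) * extra_expert_scale
```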