Does the following fix not work: https://github.com/NVIDIA/Megatron-LM/commit/daf000673726b7dee40c834f181f76703808b2fc?
In particular, these lines: https://github.com/NVIDIA/Megatron-LM/commit/daf000673726b7dee40c834f181f76703808b2fc#diff-703512d9cce575fe32a776ec738162312b6276de08ac4846a50f07e3903cfdacR239-R245.
It works. The problem is that I had been using a Megatron release from January, which lacks the separate bucket for the shared embedding. Just switched to the latest master and that solved it. Thanks!
Awesome, great to hear!
**Describe the bug**

When `use_distributed_optimizer` is enabled for models with `share_embeddings_and_output_weights` (such as GPT-2), all model gradients are reduce-scattered across DP ranks before the embedding gradients are all-reduced across PP[0] & PP[-1]. See https://github.com/NVIDIA/Megatron-LM/blob/0650d8335d45162845398a97880374b81c4d84b1/megatron/core/distributed/finalize_model_grads.py#L99-L150. Note that the `wte` gradients and `lm_head` gradients lie in different partitions of the contiguous gradient buffer (`wte` is the first weight on PP[0], `lm_head` is the last weight on PP[-1]), so they are reduce-scattered to different DP ranks. The subsequent all-reduce of the embedding gradients across PP[0] and PP[-1] within the same DP group then adds up partial results, producing wrong embedding gradients.
For example, consider only the embedding gradients with dp=2 and pp=2 on 4 GPUs: the embedding gradients on rank1 (pp0, dp1) and rank2 (pp1, dp0) are the ones the optimizer uses to update the shared weight. They are expected to be identical, but in fact they are not (see the sketch below).
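To make the ordering problem concrete, here is a minimal single-process sketch that simulates the dp=2, pp=2 case above. It is not Megatron code: the buffer layout (`wte` in the second half of PP0's buffer, `lm_head` in the first half of PP1's buffer), the in-place reduce-scatter emulation, and all names are assumptions, chosen so the shard owners match rank1 and rank2 from the report.

```python
import torch

torch.manual_seed(0)
N = 8                    # per-rank grad buffer size; each DP shard is N // 2
EMB = slice(N // 2, N)   # assumed wte region in PP0's buffer (owned by dp1)
HEAD = slice(0, N // 2)  # assumed lm_head region in PP1's buffer (owned by dp0)

# buf[(pp, dp)] is that rank's contiguous gradient buffer.
buf = {(pp, dp): torch.randn(N) for pp in range(2) for dp in range(2)}

# Ground truth for the tied weight: wte grads summed over DP plus lm_head
# grads summed over DP.
truth = (buf[(0, 0)] + buf[(0, 1)])[EMB] + (buf[(1, 0)] + buf[(1, 1)])[HEAD]

def reduce_scatter_in_place(bufs):
    """Emulate an in-place reduce-scatter across each PP stage's DP group:
    every DP rank's owned shard is overwritten with the DP sum, while the
    rest of its buffer keeps the stale local gradients."""
    total = {pp: bufs[(pp, 0)] + bufs[(pp, 1)] for pp in range(2)}
    for pp in range(2):
        bufs[(pp, 0)][: N // 2] = total[pp][: N // 2]  # dp0 owns first half
        bufs[(pp, 1)][N // 2 :] = total[pp][N // 2 :]  # dp1 owns second half

def allreduce_embedding(bufs):
    """All-reduce the wte/lm_head regions across PP[0] and PP[-1] within
    each DP group."""
    for dp in range(2):
        s = bufs[(0, dp)][EMB] + bufs[(1, dp)][HEAD]
        bufs[(0, dp)][EMB] = s
        bufs[(1, dp)][HEAD] = s

# Buggy order: reduce-scatter first, embedding all-reduce second.
bug = {k: v.clone() for k, v in buf.items()}
reduce_scatter_in_place(bug)
allreduce_embedding(bug)                      # sums stale + partial values
wte_shard = bug[(0, 1)][EMB]                  # rank1 (pp0, dp1) owns wte
head_shard = bug[(1, 0)][HEAD]                # rank2 (pp1, dp0) owns lm_head
print(torch.allclose(wte_shard, head_shard))  # False: the two ranks disagree
print(torch.allclose(wte_shard, truth))       # False: and both are wrong

# Proposed order: embedding all-reduce first, then reduce-scatter.
fix = {k: v.clone() for k, v in buf.items()}
allreduce_embedding(fix)                      # full grads still on every rank
reduce_scatter_in_place(fix)
print(torch.allclose(fix[(0, 1)][EMB], truth))   # True
print(torch.allclose(fix[(1, 0)][HEAD], truth))  # True
```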
**To Reproduce**

Run `pretrain_gpt.py` with pp=2 and dp=2 on 4 local GPUs. Before returning from `finalize_model_grads`, print the `wte` gradient hash on PP[0] and the `lm_head` gradient hash on PP[-1] (a sketch of such a hash helper follows). The gradients on rank1 and rank2 are visibly different.
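For reference, the hash can be computed along these lines; this helper and its placement are hypothetical, not code from the repo:

```python
import hashlib

import torch

def grad_hash(t: torch.Tensor) -> str:
    """Digest a gradient tensor's bytes so prints from different ranks can
    be compared; cast to float32 first, since numpy has no bfloat16."""
    data = t.detach().float().cpu().contiguous().numpy().tobytes()
    return hashlib.md5(data).hexdigest()

# Example placement just before finalize_model_grads returns (variable names
# are hypothetical):
#   rank = torch.distributed.get_rank()
#   print(f"rank {rank}: embedding grad hash = {grad_hash(weight.main_grad)}")
```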
**Expected behavior**

Correct embedding gradients with the distributed optimizer for models with tied embeddings.
**Proposed fix**

Move `_allreduce_embedding_grads` before `finish_grad_sync`. Will open a PR soon. Expected embedding gradient flow: first all-reduce the `wte`/`lm_head` gradients across PP[0] and PP[-1], then reduce-scatter the gradient buffers across DP ranks.
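In outline, the reordering would look like the sketch below. The stub and signatures are paraphrased from the linked `finalize_model_grads.py`, not taken from the actual patch, and the real function performs other reductions (e.g., layernorm grads) that are omitted here.

```python
def _allreduce_embedding_grads(model, config):
    """Stub standing in for the existing helper that all-reduces the tied
    wte/lm_head main_grads across PP[0] and PP[-1]."""

def finalize_model_grads(model, config):
    # 1) All-reduce the tied embedding grads across PP[0]/PP[-1] while every
    #    rank still holds the full, un-scattered gradient buffer.
    _allreduce_embedding_grads(model, config)
    # 2) Only then launch the DP-group grad sync; with
    #    use_distributed_optimizer this reduce-scatters the buffers, so each
    #    DP rank ends up with a consistent shard of the embedding gradient.
    for model_chunk in model:
        model_chunk.finish_grad_sync()
```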