NVIDIA / Megatron-LM

Ongoing research training transformer models at scale
https://docs.nvidia.com/megatron-core/developer-guide/latest/user-guide/index.html#quick-start

[BUG] Wrong embedding gradients with distributed optimizer and shared embedding #844

Closed · li-plus closed this issue 1 month ago

li-plus commented 1 month ago

**Describe the bug**

When `use_distributed_optimizer` is enabled for models with `share_embeddings_and_output_weights` (such as GPT-2), all model gradients are reduce-scattered across DP ranks before the embedding gradients are all-reduced across PP[0] & PP[-1]. See https://github.com/NVIDIA/Megatron-LM/blob/0650d8335d45162845398a97880374b81c4d84b1/megatron/core/distributed/finalize_model_grads.py#L99-L150

Note that the wte gradients and the lm_head gradients lie in different partitions of the contiguous gradient buffer (wte is the first weight on PP[0], lm_head is the last weight on PP[-1]), so they are reduce-scattered to different DP ranks. The subsequent all-reduce of embedding gradients across PP[0] and PP[-1] within the same DP group then adds up partial results, producing wrong embedding gradients.
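A rough illustration of why the two reduced shards end up on different DP ranks (the helper and the sizes below are hypothetical, not Megatron's actual gradient-buffer code):

    # Illustrative only: a flat per-stage gradient buffer split evenly across
    # DP ranks. After a reduce-scatter, each DP rank owns the reduced values
    # for exactly one contiguous shard of the buffer.
    def owning_dp_rank(offset: int, buffer_numel: int, dp_world_size: int) -> int:
        """DP rank whose shard contains this buffer offset after reduce-scatter."""
        shard_numel = buffer_numel // dp_world_size
        return offset // shard_numel

    buffer_numel, dp_world_size = 1024, 2   # made-up sizes

    # The tied weights sit at opposite ends of their stages' buffers, so their
    # reduced gradients are owned by different DP ranks; every other DP rank is
    # left holding only its stale local gradient for that weight.
    print(owning_dp_rank(0, buffer_numel, dp_world_size))                 # 0
    print(owning_dp_rank(buffer_numel - 1, buffer_numel, dp_world_size))  # 1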

For example, consider only the embedding gradients with dp=2 and pp=2 on 4 GPUs:

1. Before reduce-scatter across DP ranks:

| pp \ dp | 0  | 1  |
|---------|----|----|
| 0       | g0 | g1 |
| 1       | g2 | g3 |

2. After reduce-scatter across DP ranks:

| pp \ dp | 0         | 1         |
|---------|-----------|-----------|
| 0       | g0        | (g0+g1)/2 |
| 1       | (g2+g3)/2 | g3        |

3. After all-reduce of embedding grads across PP[0] & PP[-1]:

| pp \ dp | 0            | 1            |
|---------|--------------|--------------|
| 0       | g0+(g2+g3)/2 | g3+(g0+g1)/2 |
| 1       | g0+(g2+g3)/2 | g3+(g0+g1)/2 |

The embedding gradient shards on rank 1 (pp0, dp1) and rank 2 (pp1, dp0) are the ones the optimizer uses to update the shared weight. They are expected to be identical, but they are not: rank 1 holds g3+(g0+g1)/2 while rank 2 holds g0+(g2+g3)/2.
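The following plain-Python sketch reproduces the tables above with arbitrary values for g0..g3 (it simulates only the collective ordering; it is not Megatron code):

    # Simulate the faulty ordering with dp=2, pp=2. g0, g1 are the local wte
    # grads on PP[0]; g2, g3 are the local lm_head grads on PP[-1].
    g0, g1, g2, g3 = 1.0, 2.0, 4.0, 8.0

    # Step 2: reduce-scatter (with averaging) across DP ranks. Only the DP rank
    # that owns the shard receives the reduced value; the other keeps its stale
    # local gradient.
    pp0 = {0: g0, 1: (g0 + g1) / 2}      # reduced wte shard lands on dp1
    pp1 = {0: (g2 + g3) / 2, 1: g3}      # reduced lm_head shard lands on dp0

    # Step 3: all-reduce across PP[0] & PP[-1] within each DP group.
    summed = {dp: pp0[dp] + pp1[dp] for dp in (0, 1)}

    # The shards the optimizer consumes (pp0/dp1 and pp1/dp0) should match,
    # but they do not: 9.5 vs. 7.0, while the correct value is 7.5 on both.
    print(summed[1], summed[0])
    assert summed[1] != summed[0]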

**To Reproduce**

Run pretrain_gpt.py with pp=2 and dp=2 on 4 local GPUs. Before returning from finalize_model_grads, print the wte gradient hash on PP[0] and the lm_head gradient hash on PP[-1]:

    if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
        print(f'[Rank {torch.distributed.get_rank()}] embedding grad hash {model[0].module.module.language_model.embedding.word_embeddings.weight.main_grad.sum()}')
    elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
        print(f'[Rank {torch.distributed.get_rank()}] embedding grad hash {model[-1].module.module.word_embeddings.weight.main_grad.sum()}')

The printed gradients on rank 1 and rank 2 are clearly different.

**Expected behavior**

Correct embedding gradients with the distributed optimizer for models with tied embeddings.


**Proposed fix**

Move _allreduce_embedding_grads before finish_grad_sync; I will open a PR soon. The expected embedding gradient flow (a small numeric check follows the tables below):

1. Embedding gradients:

| pp \ dp | 0     | 1     |
|---------|-------|-------|
| 0       | g0    | g1    |
| 1       | g2    | g3    |

2. All-reduce embedding grads across PP[0] & PP[-1]:

| pp \ dp | 0     | 1     |
|---------|-------|-------|
| 0       | g0+g2 | g1+g3 |
| 1       | g0+g2 | g1+g3 |

3. Reduce-scatter across DP ranks:

| pp \ dp | 0               | 1               |
|---------|-----------------|-----------------|
| 0       | g0+g2           | (g0+g2+g1+g3)/2 |
| 1       | (g0+g2+g1+g3)/2 | g1+g3           |
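
A matching numeric check of the proposed flow, using the same arbitrary g0..g3 values as the earlier sketch (again a simulation, not Megatron code):

    # Proposed ordering with dp=2, pp=2 and the same arbitrary gradients.
    g0, g1, g2, g3 = 1.0, 2.0, 4.0, 8.0

    # Step 2: all-reduce embedding grads across PP[0] & PP[-1] first, so every
    # rank already holds the full tied-weight gradient g_wte + g_lm_head.
    pp0 = {0: g0 + g2, 1: g1 + g3}
    pp1 = {0: g0 + g2, 1: g1 + g3}

    # Step 3: reduce-scatter (with averaging) across DP ranks. The shards the
    # optimizer consumes now agree on both pipeline stages: 7.5 and 7.5.
    pp0_owned_shard = (pp0[0] + pp0[1]) / 2   # owned on PP[0]
    pp1_owned_shard = (pp1[0] + pp1[1]) / 2   # owned on PP[-1]
    print(pp0_owned_shard, pp1_owned_shard)
    assert pp0_owned_shard == pp1_owned_shard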
deepakn94 commented 1 month ago

Does the following fix not work: https://github.com/NVIDIA/Megatron-LM/commit/daf000673726b7dee40c834f181f76703808b2fc?

In particular, these lines: https://github.com/NVIDIA/Megatron-LM/commit/daf000673726b7dee40c834f181f76703808b2fc#diff-703512d9cce575fe32a776ec738162312b6276de08ac4846a50f07e3903cfdacR239-R245.

li-plus commented 1 month ago

It works. The problem was that I had been using the Megatron release from January, which does not have a separate bucket for the shared embedding. I just switched to the latest master and that solved it. Thanks!

deepakn94 commented 1 month ago

Awesome, great to hear!