NVIDIA / Megatron-LM

Ongoing research training transformer models at scale
https://docs.nvidia.com/megatron-core/developer-guide/latest/user-guide/index.html#quick-start

[REGRESSION] MoEs are obtaining higher loss than they should during training #894

Closed: kiddyboots216 closed this issue 1 month ago

kiddyboots216 commented 3 months ago

Describe the regression: In the forks of Megatron-LM used by GPT-NeoX and Megatron-DeepSpeed, MoEs obtain lower loss than they do in Megatron-LM with the same configuration.

To Reproduce: Attached to this issue are config files to reproduce the exact MoE we are running. megatron_125M_k1_e16_moe_3e-4_config.sh is the MoE config for Megatron-LM, megatron_dense_125M_config.sh is the dense config for Megatron-LM, and gpt-neox_e16_k1_config.yaml is the MoE config for GPT-NeoX. All models are GPT-style.

Attachments: megatron_dense_125M_config.txt, megatron_125M_k1_e16_moe_3e-4_config.txt, gpt-neox_e16_k1_config.txt

Previous performance: After step 12000 in GPT-NeoX, the MoE has a training loss of 2.452.

New performance: After step 12000 in Megatron-LM, the MoE has a training loss of 2.649, which is the same as the dense model.

Stack trace/logs: The logs are attached to this issue.

Environment: not provided.

Proposed fix: No proposed fix.

Additional context: Presumably a bug was introduced in the MoE training code. However, I looked into the GPT model code in mcore.models and was unable to find any potential cause.

bentherien commented 3 months ago

Here is some additional information:

The figure below shows validation loss curves for 125M MoE and dense models trained with megatron_125M_k1_e16_moe_3e-4_config.txt (maxLR ∈ {1e-4, 3e-4, 6e-4, 9e-4}) and megatron_dense_125M_config.txt, respectively. We observe that the MoE models underperform the dense models, contrary to results from the literature. As mentioned above, this suggests that there is a bug in Megatron-LM.

[Figure: e16-k1-moe-maxlr3e-4_vs_dense-maxlr3e-4]

Moreover, below is a plot directly comparing the training loss of dense and MoE models in Megatron-LM and GPT-NeoX, trained with GBS=768 (global batch size), SL=2048 (sequence length), E=16 (total experts), and K=1 (active experts). All models are trained on the same dataset with the same linear warmup + cosine annealing learning-rate schedule (maxLR 3e-4 to minLR 3e-5). We observe that the GPT-NeoX implementation produces results in line with the literature (e.g., Switch Transformer, Figure 1 right), while the Megatron-LM implementation does not.

This suggests there is a bug in Megatron-LM. @jaredcasper @duncanriach @jon-barker

[Figure: megatron_and_neox_comparison]
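For anyone reproducing these runs, the linear warmup + cosine annealing schedule mentioned in this comment (maxLR 3e-4 decaying to minLR 3e-5) can be sketched as follows. This is a minimal sketch assuming the common half-cosine form; warmup_steps and decay_end_step are hypothetical placeholders because the thread does not state them, and Megatron-LM's own scheduler may differ in its details.

```python
import math


def warmup_cosine_lr(step: int, max_lr: float = 3e-4, min_lr: float = 3e-5,
                     warmup_steps: int = 1000, decay_end_step: int = 20000) -> float:
    """Linear warmup to max_lr, half-cosine decay to min_lr, then hold min_lr.

    warmup_steps and decay_end_step are hypothetical placeholders.
    """
    if step < warmup_steps:
        # Linear ramp from 0 up to max_lr.
        return max_lr * step / warmup_steps
    if step >= decay_end_step:
        return min_lr
    # Fraction of the decay phase completed, in [0, 1).
    progress = (step - warmup_steps) / (decay_end_step - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * progress))  # falls from 1 to 0
    return min_lr + (max_lr - min_lr) * coeff


for s in (0, 500, 1000, 10500, 20000):
    print(s, warmup_cosine_lr(s))
```

Since the dense and MoE runs share this schedule, differences in their loss curves cannot be attributed to the learning-rate schedule itself.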

kiddyboots216 commented 3 months ago
[Figure: 125M_exps8-val-loss]

Here is the validation loss plot for additional MoE configurations, again swept over several learning rates; all of them underperform the dense model.

yqli2420 commented 3 months ago

[Quoting bentherien's comment above]

Personally, I think we should change the horizontal axis to FLOPs and then compare the loss.

kiddyboots216 commented 3 months ago

[Quoting yqli2420's comment above]

These MoEs are all K=1, so they are already FLOPs-matched with the dense model; in other words, the plots would look the same with FLOPs on the horizontal axis.
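The FLOPs-matching argument can be checked with a back-of-the-envelope count. The sketch below is illustrative and not taken from the thread: it assumes GPT-3-small-style shapes for the 125M model (hidden size 768, FFN size 3072), assumes each expert has the same shape as the dense MLP, counts 2 FLOPs per multiply-accumulate, and ignores attention, which is identical in the dense and MoE models. With K=1, each token passes through exactly one expert MLP, so the only extra cost is the linear router.

```python
def mlp_flops_per_token(d_model: int, d_ffn: int) -> int:
    """Forward FLOPs of one two-layer MLP for one token (2 FLOPs per multiply-accumulate)."""
    up_proj = 2 * d_model * d_ffn
    down_proj = 2 * d_ffn * d_model
    return up_proj + down_proj


def moe_flops_per_token(d_model: int, d_ffn: int, num_experts: int, top_k: int) -> int:
    """FLOPs of an MoE MLP layer for one token: top_k expert MLPs plus a linear router."""
    router = 2 * d_model * num_experts  # one logit per expert
    return top_k * mlp_flops_per_token(d_model, d_ffn) + router


d_model, d_ffn = 768, 3072  # assumed 125M (GPT-3 small) shapes
dense = mlp_flops_per_token(d_model, d_ffn)
moe = moe_flops_per_token(d_model, d_ffn, num_experts=16, top_k=1)
print(dense, moe, round(moe / dense, 4))  # 9437184 9461760 1.0026
```

Under these assumptions, the K=1 MoE spends only about 0.26% more MLP FLOPs per token than the dense model, so rescaling the horizontal axis to FLOPs would barely move the curves.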

zainsarwar865 commented 3 months ago

Running the same config with Megatron-DeepSpeed does result in the MoE outperforming the dense model. This was run with 8 experts, top-k=1, and a 125M base model.

[Figure: Deepspeed_moe_dense_e-8_topk1]

yanring commented 3 months ago

Thank you for reporting the issue! We will investigate it and get back to you soon.

bentherien commented 3 months ago

Thank you @yanring !

yanring commented 3 months ago

Hi @kiddyboots216 @bentherien ,

We have done some investigations and discovered that the issue is specific to top-1 routing, and the root cause is the ordering of softmax and top-k. In short, we should apply softmax before selecting the top-k experts when k equals 1, since performing softmax on a [num_tokens, 1] tensor always yields 1 and therefore a gradient of 0. Below are our experiments and code changes:

[Images: experiment results and code changes]

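The effect of the softmax/top-k ordering can be reproduced with a toy, pure-Python sketch. This is not Megatron-LM's router code: gate_weights, top1_weight_grad, and the pre_softmax switch are hypothetical names chosen to mirror the two orderings, and the gradient is estimated by central finite differences on a single token's router logits. With top-k applied first and k=1, the softmax runs over a single logit, so the gate weight is always exactly 1.0 and its gradient with respect to every logit is zero. Applying softmax over all experts first gives a non-zero gradient.

```python
import math
from typing import List


def softmax(xs: List[float]) -> List[float]:
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def topk_indices(xs: List[float], k: int) -> List[int]:
    """Indices of the k largest entries, largest first."""
    return sorted(range(len(xs)), key=lambda i: xs[i], reverse=True)[:k]


def gate_weights(logits: List[float], k: int, pre_softmax: bool) -> List[float]:
    """Routing weights of the k selected experts for one token (hypothetical helper)."""
    if pre_softmax:
        # Softmax over all experts first, then keep the k largest probabilities.
        probs = softmax(logits)
        return [probs[i] for i in topk_indices(probs, k)]
    # Select the k largest logits first, then softmax over only those k values.
    idx = topk_indices(logits, k)
    return softmax([logits[i] for i in idx])


def top1_weight_grad(logits: List[float], k: int, pre_softmax: bool,
                     eps: float = 1e-6) -> List[float]:
    """Central-difference gradient of the top-ranked gate weight w.r.t. each logit."""
    grads = []
    for j in range(len(logits)):
        up, down = list(logits), list(logits)
        up[j] += eps
        down[j] -= eps
        diff = gate_weights(up, k, pre_softmax)[0] - gate_weights(down, k, pre_softmax)[0]
        grads.append(diff / (2 * eps))
    return grads


logits = [2.0, 0.5, -1.0, 0.1]  # one token, four experts
print(gate_weights(logits, 1, pre_softmax=False))      # [1.0]
print(top1_weight_grad(logits, 1, pre_softmax=False))  # [0.0, 0.0, 0.0, 0.0]
print(top1_weight_grad(logits, 1, pre_softmax=True))   # non-zero entries
```

With k=2 and top-k applied first, the softmax runs over two logits and is no longer constant, so the gate weight does receive a gradient; this is consistent with the problem being specific to top-1 routing.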
zainsarwar865 commented 3 months ago

Thanks @yanring. I did observe that with top-k=2, the improvement over top-k=1 was significantly larger than the literature suggests.

yanring commented 2 months ago

FYI: We have added the argument --moe-router-pre-softmax on commit

Thank you all!