NVIDIA / Megatron-LM

Ongoing research training transformer models at scale
https://docs.nvidia.com/megatron-core/developer-guide/latest/user-guide/index.html#quick-start

[BUG] Performance drop while training with MoE #716

Open Teng-xu opened 6 months ago

Teng-xu commented 6 months ago

Describe the bug

During training runs using Megatron's Mixture of Experts (MoE) layers, we observed a drop in performance at specific steps, with the slowdown appearing sporadically and inconsistently throughout training. We have also done some memory profiling and found that execution time during the slow steps is dominated by all-gather and reduce-scatter calls, which account for 99% of the time spent in those steps. We are looking for insight into potential causes of this performance issue.
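
As a minimal sketch, something like the following torch.profiler snippet can confirm that NCCL collectives dominate a slow step (this is not necessarily how the profiling above was done; train_step and batch are placeholders for the actual training loop):

import torch
from torch.profiler import profile, ProfilerActivity

# Wrap one suspect iteration; sorting by CUDA time shows whether NCCL
# collectives (all_gather / reduce_scatter) dominate, as described above.
# train_step and batch are placeholders for the actual training loop.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    loss = train_step(batch)
    torch.cuda.synchronize()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))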

To Reproduce

# Imports assumed for this snippet (exact paths may vary by Megatron-Core version):
import torch.nn.functional as F
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    tensor_model_parallel_size=1, context_parallel_size=1, pipeline_model_parallel_size=1,
    expert_model_parallel_size=2, num_layers=32, hidden_size=4096,
    num_attention_heads=32, layernorm_epsilon=1e-05,
    add_bias_linear=False, activation_func=F.silu, num_moe_experts=8,
    fp8=None, normalization='RMSNorm', moe_router_load_balancing_type='sinkhorn',
    moe_router_topk=2, moe_grouped_gemm=True, moe_aux_loss_coeff=0.0,
    moe_z_loss_coeff=None, moe_input_jitter_eps=None, moe_token_dropping=False
)

Expected behavior

The throughput should be stable across training steps.

Stack trace/logs

Batch 138 Loss: 6.425852298736572, Speed: 6.26 samples/sec, Model TFLOPS/GPU: 320.16
Batch 139 Loss: 6.429311275482178, Speed: 6.30 samples/sec, Model TFLOPS/GPU: 321.91
Batch 140 Loss: 6.296842575073242, Speed: 0.05 samples/sec, Model TFLOPS/GPU: 2.39
Batch 141 Loss: 6.297295570373535, Speed: 0.26 samples/sec, Model TFLOPS/GPU: 13.24


Environment (please complete the following information):

ktaebum commented 6 months ago

I also encountered the same problem with a non-MoE model. I tried running a Llama 13B training job on two DGX A100 nodes, and the time breakdown shows:

    forward-backward ...............................: (5662.82, 5666.72)
    forward-compute ................................: (2146.30, 2210.56)
    backward-compute ...............................: (3431.20, 3509.58)
    batch-generator ................................: (17.31, 33.45)
    layernorm-grads-all-reduce .....................: (5.24, 218.94)
    embedding-grads-all-reduce .....................: (0.06, 0.11)
    all-grads-sync .................................: (215891.91, 225072.22)
    optimizer-copy-to-main-grad ....................: (9.13, 9.19)
    optimizer-unscale-and-check-inf ................: (9.69, 9.88)
    optimizer-clip-main-grad .......................: (14.55, 14.77)
    optimizer-count-zeros ..........................: (0.02, 0.07)
    optimizer-inner-step ...........................: (31.58, 32.33)
    optimizer-copy-main-to-model-params ............: (9.36, 9.57)
    optimizer ......................................: (77.15, 77.37)

(I disabled all overlap-* optimizations and the distributed optimizer for a more accurate time breakdown.) Gradient AllReduce takes more than 200 seconds while forward-backward takes just 5.6 seconds. The problem occurs regardless of whether the distributed optimizer is used:

    forward-backward ...............................: (6640.79, 6647.08)
    forward-compute ................................: (3118.90, 3181.81)
    backward-compute ...............................: (3428.96, 3512.83)
    batch-generator ................................: (16.72, 34.26)
    layernorm-grads-all-reduce .....................: (4.97, 11.08)
    embedding-grads-all-reduce .....................: (0.06, 0.12)
    all-grads-sync .................................: (77025.69, 112368.28)
    params-all-gather ..............................: (77461.61, 112343.53)
    optimizer-copy-to-main-grad ....................: (4.65, 4.82)
    optimizer-unscale-and-check-inf ................: (5.37, 5.39)
    optimizer-clip-main-grad .......................: (7.70, 7.74)
    optimizer-count-zeros ..........................: (0.02, 0.03)
    optimizer-inner-step ...........................: (15.89, 16.30)
    optimizer-copy-main-to-model-params ............: (4.53, 4.56)
    optimizer ......................................: (77502.36, 112384.28)

When I run the same job on a single node, the problem disappears.

My environment is

ktaebum commented 6 months ago

My issue has been resolved by passing --device=/dev/infiniband as a docker run argument.

rahul003 commented 6 months ago

ktaebum's issue is unrelated. We only notice the slowdown in some steps, and it comes from intra-node AllGather calls, which take surprisingly long during those steps.

dawson-chen commented 5 months ago

I have encountered the same problem with MoE when the router type is sinkhorn and topk > 1.


From my log, I found that the main time consumption comes from the sinkhorn function:

norm_logits = sinkhorn(logits.to(dtype=torch.float32))
dawson-chen commented 5 months ago

When topk > 1 and the router type is sinkhorn, the inner loop of the sinkhorn function runs thousands of iterations for some logits.

But I didn't find any clue in those logits; they look similar to normal ones.
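
For context, here is a simplified sketch of a tolerance-driven Sinkhorn normalization with the same loop structure as the routing code in question (an illustration, not a verbatim copy of the Megatron source). The while loop exits only when the per-expert scaling factors converge, so the iteration count is data-dependent and effectively unbounded:

import torch

def sinkhorn(cost: torch.Tensor, tol: float = 1e-4) -> torch.Tensor:
    # cost: (num_tokens, num_experts) router logits.
    cost = torch.exp(cost)
    d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)  # per-token scaling
    d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)  # per-expert scaling
    eps = 1e-8
    error = 1e9
    d1_old = d1
    while error > tol:  # no iteration cap: a hard-to-converge batch keeps looping
        d0 = (1.0 / d0.size(0)) / (torch.sum(d1 * cost, dim=1) + eps)
        d1 = (1.0 / d1.size(0)) / (torch.sum(d0.unsqueeze(1) * cost, dim=0) + eps)
        error = torch.mean(torch.abs(d1_old - d1))
        d1_old = d1
    # Only the per-expert (column) scaling affects each token's top-k choice,
    # so the per-token factor d0 is not applied to the returned scores.
    return d1 * cost * d1.size(0)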

yanring commented 5 months ago

@Teng-xu @dawson-chen Thanks for reporting this issue. This could be due to too many iterations in Sinkhorn on some ranks. You can try adding an early stop to Sinkhorn or using aux_loss for load balancing.
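
A minimal sketch of the suggested early stop (the function name and max_iters value are illustrative, not from the Megatron source): capping the number of Sinkhorn iterations keeps one hard-to-converge batch of logits from stalling a rank.

def sinkhorn_with_cap(cost, tol=1e-4, max_iters=100):
    # Same fixed-point iteration as the sketch above, but bounded by max_iters.
    cost = torch.exp(cost)
    d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
    d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)
    eps = 1e-8
    d1_old = d1
    for _ in range(max_iters):
        d0 = (1.0 / d0.size(0)) / (torch.sum(d1 * cost, dim=1) + eps)
        d1 = (1.0 / d1.size(0)) / (torch.sum(d0.unsqueeze(1) * cost, dim=0) + eps)
        if torch.mean(torch.abs(d1_old - d1)) < tol:
            break
        d1_old = d1
    return d1 * cost * d1.size(0)

The other workaround maps to the config shown in the original report: set moe_router_load_balancing_type='aux_loss' with a non-zero moe_aux_loss_coeff instead of the sinkhorn router.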

wen020 commented 3 months ago

How do you get the Model TFLOPS/GPU number?
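
For reference, one common way to derive such a number is the rough 6 * params * tokens approximation for forward plus backward FLOPs (a sketch only; it ignores attention and activation-recomputation terms, and the exact formula a given training script prints may differ):

def model_tflops_per_gpu(num_params, tokens_per_iter, iter_time_s, num_gpus):
    # ~6 FLOPs per parameter per token covers the forward and backward passes.
    flops = 6.0 * num_params * tokens_per_iter
    return flops / (iter_time_s * num_gpus) / 1e12

# Example: 13B params, global batch of 512 sequences of length 2048, 30 s/iter, 16 GPUs
print(model_tflops_per_gpu(13e9, 512 * 2048, 30.0, 16))   # ~170 TFLOPS/GPU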

github-actions[bot] commented 1 month ago

Marking as stale. No activity in 60 days.