NVIDIA / Megatron-LM


[BUG] when using --use-mcore-models and --overlap-param-gather #950

Open Kingsleyandher opened 1 month ago

Kingsleyandher commented 1 month ago

Describe the bug When the order of the (FP16/BF16) parameters in the buffer differs from the model's forward execution order, enabling --overlap-param-gather with the distributed optimizer makes the bucket-by-bucket all-gather (parameter update) order inconsistent with the forward execution order.

As a result, when some parameters are used in the forward pass, the bucket containing them has not yet been updated by all-gather, so the forward computation for those parameters still uses the previous step's weights.

In the most extreme case, suppose parameters A, B, and C sit in bucket 1, bucket 2, and bucket 3 respectively, but the forward execution order is A, C, B. Then all_gather_handle is still None during the forward computation of C.

https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/optimizer/distrib_optimizer.py#L1413
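
The pure-Python sketch below illustrates the pattern (the names Bucket, dispatch_all_gather, and forward_pre_hook are illustrative only, not the actual Megatron-LM classes or hooks): all-gathers are dispatched lazily in buffer/bucket order from forward pre-hooks, so a parameter whose forward pass runs ahead of that order finds its bucket's all_gather_handle still None, the wait is skipped, and last step's weights are used.

```python
class Bucket:
    def __init__(self, name):
        self.name = name
        self.all_gather_handle = None   # set once the bucket's all-gather is dispatched
        self.params_are_fresh = False   # True once the all-gather has been waited on

    def dispatch_all_gather(self):
        # The real code launches an async all-gather and stores the returned handle.
        self.all_gather_handle = object()

    def wait_all_gather(self):
        # The real code calls handle.wait(); here we just mark the params as updated.
        self.params_are_fresh = True


buckets = {"A": Bucket("bucket 1"), "B": Bucket("bucket 2"), "C": Bucket("bucket 3")}
bucket_order = ["A", "B", "C"]   # order of the parameters in the buffer

# Only the first bucket's all-gather is issued up front; the remaining buckets
# are prefetched lazily from the forward pre-hooks, in buffer order.
buckets[bucket_order[0]].dispatch_all_gather()


def forward_pre_hook(param_name):
    bucket = buckets[param_name]
    if bucket.all_gather_handle is not None:
        bucket.wait_all_gather()   # block until this bucket holds the new weights
    # else: the all-gather was never dispatched, the wait is silently skipped,
    # and the forward pass runs with the previous step's weights.

    # Prefetch the next bucket in *buffer* order (not forward order).
    idx = bucket_order.index(param_name)
    if idx + 1 < len(bucket_order):
        buckets[bucket_order[idx + 1]].dispatch_all_gather()


for param_name in ["A", "C", "B"]:   # forward execution order from the example above
    forward_pre_hook(param_name)
    print(param_name, "fresh params:", buckets[param_name].params_are_fresh)

# Prints: A True, C False, B True -- C's forward pass uses stale weights.
```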

To Reproduce Use DP=4, train 100 steps with and without --overlap-param-gather, and compare the losses.

GPT_ARGS="
    --tensor-model-parallel-size 1 \
    --pipeline-model-parallel-size 1 \
    --overlap-grad-reduce \
    --use-mcore-models \
    --use-distributed-optimizer \
    --sequence-parallel \
    --num-layers 1 \
    --hidden-size 8192   \
    --num-attention-heads 64 \
    --seq-length 2048 \
    --max-position-embeddings 2048 \
    --micro-batch-size 1 \
    --global-batch-size 4 \
    --train-iters 1000 \
    --lr-decay-iters 320000 \
    --lr 5.0e-7 \
    --lr-decay-style cosine \
    --clip-grad 1.0 \
    --weight-decay 0.1 \
    --adam-beta1 0.9 \
    --adam-beta2 0.95 \
    --init-method-std 0.006 \
    --no-gradient-accumulation-fusion \
    --use-flash-attn \
    --disable-bias-linear \
    --position-embedding-type rope \
    --attention-dropout 0.0 \
    --hidden-dropout 0.0 \
    --bf16
"
deepakn94 commented 1 month ago

Hmm, the above script is non-deterministic (e.g., because of --use-flash-attn), so I wouldn't expect two executions of the same script to produce identical losses (if that is what you were expecting).

Moreover, even with deterministic execution, it is expected that the logged losses will look different with and without --overlap-param-gather. In particular, if you set --log-interval 1, you should see that the losses are identical with an offset of 1 (please also use the --deterministic-mode flag).
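
If it helps, here is a small hypothetical helper (not part of Megatron-LM) for checking that from two logs produced with --log-interval 1 and --deterministic-mode. It assumes each per-iteration log line contains a field like "lm loss: <value>"; adjust the regex to your actual log format.

```python
import re
import sys

# Hypothetical comparison script; assumes "lm loss: <value>" appears in each
# logged iteration line. Adjust LOSS_RE if your logs use a different field.
LOSS_RE = re.compile(r"lm loss: ([0-9.eE+\-]+)")

def read_losses(path):
    with open(path) as f:
        return [float(m.group(1)) for line in f if (m := LOSS_RE.search(line))]

baseline = read_losses(sys.argv[1])   # log from the run without --overlap-param-gather
overlap = read_losses(sys.argv[2])    # log from the run with --overlap-param-gather

def matches(xs, ys):
    # Count positions where the two loss series agree exactly.
    return sum(1 for a, b in zip(xs, ys) if abs(a - b) < 1e-12)

print("aligned:            ", matches(overlap, baseline))
print("offset by one step: ", max(matches(overlap[1:], baseline[:-1]),
                                  matches(overlap[:-1], baseline[1:])))
```

Usage would be something like python compare_losses.py baseline.log overlap.log (the file names are just placeholders); with deterministic execution the "offset by one step" count should cover essentially every logged iteration.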