NVIDIA / Megatron-LM

Ongoing research training transformer models at scale
https://docs.nvidia.com/megatron-core/developer-guide/latest/user-guide/index.html#quick-start

[REGRESSION] get_batch in pretrain_gpt.py is much slower than the old impl #678

Open aoyulong opened 10 months ago

aoyulong commented 10 months ago

Describe the regression
The new get_batch implementation in pretrain_gpt.py is much slower than the old one as the sequence length increases.

To Reproduce
The code was tested on an A800-80GB server equipped with NVLink. The configuration is as follows; only --seq-length and --max-position-embeddings are changed between runs.

DISTRIBUTED_ARGS="
    --nproc_per_node 8 \
    --nnodes 1 \
    --node_rank 0 \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT
"

TRAINING_ARGS="
    --train-samples 10000 \
    --eval-iters 0 \
    --eval-interval 2000 \
    --tensor-model-parallel-size 4 \
    --pipeline-model-parallel-size 2 \
    --make-vocab-size-divisible-by 64 \
    --micro-batch-size 1 \
    --global-batch-size 4 \
    --disable-bias-linear \
    --sequence-parallel \
    --use-flash-attn
"

MIXED_PRECISION_ARGS="
    --bf16 \
    --attention-softmax-in-fp32 \
    --accumulate-allreduce-grads-in-fp32
"

DATA_ARGS="
    --data-path $DATA_PATH \
    --tokenizer-type GPT2BPETokenizer \
    --vocab-file $VOCAB_FILE \
    --vocab-size 100008 \
    --merge-file $MERGE_FILE \
    --split 1
"
    # --special-tokens-file $SPECIAL_TOKENS_FILE

NETWORK_ARGS="
    --num-layers 24 \
    --hidden-size 2048 \
    --num-attention-heads 16 \
    --seq-length 4096 \
    --max-position-embeddings 4096 \
    --normalization RMSNorm \
    --use-rotary-position-embeddings \
    --no-position-embedding \
    --swiglu \
    --untie-embeddings-and-output-weights
"

INITIALIZATION_ARGS="
    --init-method-std 0.02 \
    --seed 42
"

REGULARIZATION_ARGS="
    --attention-dropout 0.0 \
    --hidden-dropout 0.0 \
    --weight-decay 0.1 \
    --adam-beta1 0.9 \
    --adam-beta2 0.95 \
    --clip-grad 1.0
"

LEARNING_RATE_ARGS="
    --lr 1.0e-3 \
    --lr-decay-style cosine \
    --lr-warmup-samples 100 \
    --min-lr 1.0e-5
"

CHECKPOINTING_ARGS="
    --save-interval 20000 \
    --save $CHECKPOINT_PATH
"

LOGGING_ARGS="
    --log-interval 1 \
    --timing-log-level 2 \
    --tensorboard-dir $TB_PATH \
    --tensorboard-log-interval 1 \
    --wandb-save-dir $WB_PATH
"

cmd="torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
              $TRAINING_ARGS \
              $MIXED_PRECISION_ARGS \
              $DATA_ARGS \
              $NETWORK_ARGS \
              $INITIALIZATION_ARGS \
              $REGULARIZATION_ARGS \
              $LEARNING_RATE_ARGS \
              $CHECKPOINTING_ARGS \
              $LOGGING_ARGS
    "
echo $cmd
eval $cmd
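
To confirm that the extra time is spent in get_batch itself rather than in the forward/backward pass, the --timing-log-level 2 flag used above should already surface a per-iteration batch-generator timer in the log. Alternatively, a minimal wall-clock wrapper like the sketch below can be dropped around the call; the usage line is hypothetical and assumes get_batch(data_iterator) from pretrain_gpt.py with its usual five-tuple return value.

import time
import torch

def timed(fn):
    """Wrap a callable and print its wall-clock time per call.

    CUDA work launched inside the callable is synchronized before and
    after the measurement so asynchronous kernels are not hidden.
    """
    def wrapper(*args, **kwargs):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        print(f"{fn.__name__}: {(time.perf_counter() - start) * 1e3:.2f} ms")
        return result
    return wrapper

# Hypothetical usage inside forward_step in pretrain_gpt.py:
#   tokens, labels, loss_mask, attention_mask, position_ids = timed(get_batch)(data_iterator)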

Performance comparison (only --seq-length / --max-position-embeddings varied):

    commit id                                       seq_len=4096  seq_len=8192  seq_len=16384  seq_len=32768
    new: de4028a9d45bd65c67e1a201d9e0690bd6cb4304          360.1         789.8         2914.0        11820.5
    old: 67a0e5df1a51461d707bf6609ce44993eaaee545          342.5         613.3         1273.1            OOM

Environment (please complete the following information):

Proposed fix
Revert the get_batch function to the original implementation.
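
As a way to quantify the proposed fix in isolation, a rough A/B harness along the lines of the sketch below could time each variant outside the full training loop. Here get_batch_old and get_batch_new are hypothetical names for the function taken from the two commits in the table above, both assumed to accept a data iterator.

import time
import torch

def bench(get_batch_fn, data_iterator, warmup=3, iters=10):
    """Return the mean wall-clock time per get_batch call in milliseconds."""
    for _ in range(warmup):
        get_batch_fn(data_iterator)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        get_batch_fn(data_iterator)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) * 1e3 / iters

# Hypothetical usage with two iterators over the same dataset so both
# variants see comparable batches:
#   ms_old = bench(get_batch_old, iter(dataloader_a))
#   ms_new = bench(get_batch_new, iter(dataloader_b))
#   print(f"old: {ms_old:.1f} ms/call  new: {ms_new:.1f} ms/call")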

github-actions[bot] commented 8 months ago

Marking as stale. No activity in 60 days.