Open pluiez opened 10 months ago
This can be reproduced by cloning the latest Megatron-LM and setting `--transformer-impl transformer_engine` instead of using the local implementation. The experiments are run in an `nvcr.io/nvidia/pytorch:23.11-py3` container with 8 H800 GPUs.
This is caused by the float32 operations in `apply_rotary_pos_emb`.

Training log for the original `apply_rotary_pos_emb` implementation:
Specifically, this can be modified to use lower precision by changing line 820 of `apply_rotary_pos_emb` to the following:
```python
t = (t * freqs.cos().to(t.dtype)) + (_rotate_half(t) * freqs.sin().to(t.dtype))
```
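For context, here is a minimal standalone sketch of the two variants (not the exact Megatron-LM code; the function names and tensor shapes below are illustrative). The point is PyTorch's type promotion: multiplying bf16 activations by the float32 `freqs.cos()`/`freqs.sin()` promotes the whole rotary computation to float32, while casting the cos/sin factors to `t.dtype` first keeps it in bf16.

```python
import torch

def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_fp32(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # freqs is float32, so the multiply promotes bf16 `t` to float32
    # and the result stays in float32.
    return (t * freqs.cos()) + (_rotate_half(t) * freqs.sin())

def apply_rotary_low_precision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # Cast cos/sin to t.dtype first, so the computation stays in bf16.
    return (t * freqs.cos().to(t.dtype)) + (_rotate_half(t) * freqs.sin().to(t.dtype))

if __name__ == "__main__":
    # Hypothetical shapes: [seq, batch, heads, head_dim] with bf16 activations.
    t = torch.randn(2048, 4, 16, 128, dtype=torch.bfloat16)
    freqs = torch.randn(2048, 1, 1, 128, dtype=torch.float32)
    print(apply_rotary_fp32(t, freqs).dtype)           # torch.float32
    print(apply_rotary_low_precision(t, freqs).dtype)  # torch.bfloat16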
Here is the training log after modification:
The following training command is used:
```bash
torchrun --nproc_per_node 8 --nnodes 1 --node_rank 0 --master_addr localhost --master_port 9260 \
    pretrain_gpt.py \
    --use-flash-attn \
    --bf16 \
    --transformer-impl transformer_engine \
    --use-distributed-optimizer \
    --num-layers 24 \
    --hidden-size 2048 \
    --num-attention-heads 16 \
    --seq-length 2048 \
    --max-position-embeddings 2048 \
    --micro-batch-size 4 \
    --global-batch-size 512 \
    --lr 3e-4 \
    --train-iters 10000 \
    --lr-decay-iters 10000 \
    --lr-decay-style cosine \
    --lr-warmup-fraction .01 \
    --min-lr 3e-5 \
    --weight-decay 0.1 \
    --clip-grad 1.0 \
    --swiglu \
    --normalization RMSNorm \
    --disable-bias-linear \
    --attention-dropout 0.0 \
    --hidden-dropout 0.0 \
    --untie-embeddings-and-output-weights \
    --use-rotary-position-embeddings \
    --no-position-embedding \
    --no-masked-softmax-fusion \
    --eod-mask-loss \
    --data-path 1.0 /PATH/TO/DATA \
    --tokenizer-model /PATH/TO/TOKENIZER \
    --tokenizer-type GPTSentencePieceTokenizer \
    --log-interval 1 \
    --log-throughput \
    --distributed-backend nccl
```
PR #517 should significantly help here.