NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

`apply_rotary_pos_emb` significantly hurts training efficiency #552

Open pluiez opened 10 months ago

pluiez commented 10 months ago

This can be reproduced by cloning the latest Megatron-LM and setting `--transformer-impl` to `transformer_engine` instead of the `local` implementation.

The experiments are run in an `nvcr.io/nvidia/pytorch:23.11-py3` container with 8 H800 GPUs.

This is caused by the float32 operations in `apply_rotary_pos_emb`.

Training log for the original `apply_rotary_pos_emb` implementation: [training log screenshot]

Specifically, `apply_rotary_pos_emb` can be modified to use lower precision by changing line 820 to the following:

```python
t = (t * freqs.cos().to(t.dtype)) + (_rotate_half(t) * freqs.sin().to(t.dtype))
```
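For context, here is a minimal standalone sketch of the two code paths being compared. It is only an illustration of the dtype handling (the function and variable names are placeholders, not the exact TransformerEngine source): the first variant upcasts the activation to float32 before the rotation, while the second keeps the activation in its original dtype and only casts the small cos/sin tables, as in the one-line change above.

```python
import torch

def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_fp32(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # Upcast the activation to float32 before the rotation, then cast back.
    t_fp32 = t.float()
    out = (t_fp32 * freqs.cos()) + (_rotate_half(t_fp32) * freqs.sin())
    return out.to(t.dtype)

def apply_rope_low_precision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # Keep the activation in its original dtype (e.g. bf16); only the small
    # cos/sin tables are cast, as in the suggested one-line change.
    return (t * freqs.cos().to(t.dtype)) + (_rotate_half(t) * freqs.sin().to(t.dtype))
```

With bf16 activations, the first variant materializes float32 temporaries for every element of `t`, so the extra cost comes mostly from the additional memory traffic of the upcast path.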

Here is the training log after the modification: [training log screenshot]

The following training command is used:

```bash
torchrun --nproc_per_node 8 --nnodes 1 --node_rank 0 --master_addr localhost --master_port 9260 \
    pretrain_gpt.py \
    --use-flash-attn \
    --bf16 \
    --transformer-impl transformer_engine \
    --use-distributed-optimizer \
    --num-layers 24 \
    --hidden-size 2048 \
    --num-attention-heads 16 \
    --seq-length 2048 \
    --max-position-embeddings 2048 \
    --micro-batch-size 4 \
    --global-batch-size 512 \
    --lr 3e-4 \
    --train-iters 10000 \
    --lr-decay-iters 10000 \
    --lr-decay-style cosine \
    --lr-warmup-fraction .01 \
    --min-lr 3e-5 \
    --weight-decay 0.1 \
    --clip-grad 1.0 \
    --swiglu \
    --normalization RMSNorm \
    --disable-bias-linear \
    --attention-dropout 0.0 \
    --hidden-dropout 0.0 \
    --untie-embeddings-and-output-weights \
    --use-rotary-position-embeddings \
    --no-position-embedding \
    --no-masked-softmax-fusion \
    --eod-mask-loss \
    --data-path 1.0 /PATH/TO/DATA \
    --tokenizer-model /PATH/TO/TOKENIZER \
    --tokenizer-type GPTSentencePieceTokenizer \
    --log-interval 1 \
    --log-throughput \
    --distributed-backend nccl
```
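To isolate the cost of the RoPE application outside a full training run, a rough microbenchmark along the following lines could be used (a sketch that reuses the two illustrative functions above; the tensor shapes are taken from the command's sequence length, micro-batch size, and head configuration, i.e. head dimension 2048 / 16 = 128):

```python
import torch

def benchmark(fn, t, freqs, iters=100):
    # Warm up, then time `iters` calls with CUDA events (milliseconds per call).
    for _ in range(10):
        fn(t, freqs)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(t, freqs)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# Shapes follow the training command: [seq, batch, heads, head_dim].
# `freqs` holds the rotary angles; random values are fine for timing purposes.
t = torch.randn(2048, 4, 16, 128, dtype=torch.bfloat16, device="cuda")
freqs = torch.randn(2048, 1, 1, 128, dtype=torch.float32, device="cuda")

# apply_rope_fp32 / apply_rope_low_precision are the sketch functions above.
print(f"fp32 path: {benchmark(apply_rope_fp32, t, freqs):.3f} ms")
print(f"bf16 path: {benchmark(apply_rope_low_precision, t, freqs):.3f} ms")
```
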
ptrendx commented 10 months ago

PR #517 should significantly help here.