google / paxml

Pax is a Jax-based machine learning framework for training large scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry leading model flop utilization rates.
Apache License 2.0

[Question] Very low MFU (30%~35%) when training bf16 Llama2 and GPT models on a single SXM4 A100 machine. #65

Open MoFHeka opened 9 months ago

MoFHeka commented 9 months ago

I don't know what is happening. Are the computation precision and parameter precision not set correctly? DeepSpeed and Megatron can easily reach 55% MFU on the same machine. Here is my bash script:

#! /bin/bash
set -u
set -o pipefail

TFDS_DATA_DIR=$1
VOCAB_PATH=$2
PREC=${3:-"bfloat16"}        # Precision (float32, bfloat16)
NUM_GPUS=${4:-8}      # Number of GPUs (1, 2, 4, 8)
PERCORE_BATCH_SIZE=${5:-4}
LOG_DIR=${6:-"test_logdir"}

export VOCAB_PATH=$VOCAB_PATH

BASE_XLA_FLAGS=${BASE_XLA_FLAGS:-"--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false
                       --xla_gpu_simplify_all_fp_conversions --xla_gpu_enable_async_all_gather=true
                       --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_highest_priority_async_stream=true
                       --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_all_reduce_combine_threshold_bytes=51200
                       --xla_gpu_graph_level=3 --xla_gpu_enable_async_all_reduce=true
                       --xla_gpu_enable_async_collectives=true --xla_gpu_enable_async_collective_permute=true
                       --xla_gpu_enable_async_all_to_all=true --xla_gpu_all_reduce_contiguous=true
                       --xla_gpu_all_reduce_blueconnect_num_devices_per_host=true
                       --xla_gpu_enable_cudnn_frontend=true --xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true
                       --xla_gpu_enable_cudnn_layer_norm"}
export XLA_FLAGS="$BASE_XLA_FLAGS ${XLA_FLAGS:-}"

export ENABLE_TE=1

mkdir -p ${LOG_DIR}
python3 -u -m paxml.main \
    --job_log_dir=${LOG_DIR} \
    --fdl_config=paxml.tasks.lm.params.nvidia.Llama2_7B \
    --fdl.FPROP_DTYPE=\"${PREC}\" \
    --fdl.ICI_MESH_SHAPE="[1,${NUM_GPUS},1]" \
    --fdl.DCN_MESH_SHAPE="[1,1,1]" \
    --fdl.NUM_STAGES=1 \
    --fdl.MICROBATCH_SIZE=$PERCORE_BATCH_SIZE \
    --fdl.PERCORE_BATCH_SIZE=$PERCORE_BATCH_SIZE \
    --tfds_data_dir=$TFDS_DATA_DIR \
    --alsologtostderr \
    2>&1 | tee ${LOG_DIR}/llama2_7B_output.log

EXP_STATUS=$?

if [ ${EXP_STATUS} -ne 0 ]; then
  echo "Run failed"
else
  echo "Run succeeded!"
fi
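
For reference, one way to rule out a hardware or driver problem before blaming the framework config is to measure raw bf16 matmul throughput on a single GPU with plain JAX. This is just a rough sketch (the 8192 matrix size and iteration count are arbitrary choices of mine), not part of the training setup:

# Rough single-GPU baseline check (hypothetical, not part of the training run):
# measure achieved bf16 matmul TFLOP/s against the A100's 312 TFLOP/s bf16 peak.
import time

import jax
import jax.numpy as jnp

n = 8192  # arbitrary size, large enough to saturate the GPU
x = jnp.ones((n, n), dtype=jnp.bfloat16)
y = jnp.ones((n, n), dtype=jnp.bfloat16)

matmul = jax.jit(jnp.matmul)
matmul(x, y).block_until_ready()  # compile and warm up

iters = 50
start = time.perf_counter()
for _ in range(iters):
    out = matmul(x, y)
out.block_until_ready()
elapsed = time.perf_counter() - start

achieved = 2 * n**3 * iters / elapsed  # 2*n^3 FLOPs per n-by-n matmul
print(f"{achieved / 1e12:.1f} TFLOP/s (A100 bf16 peak: 312)")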

According to https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/pax, NVIDIA trains a 5B GPT model in native BF16 on 256 A100 GPUs and reports a throughput of 465.45 sequences/sec at a global batch size of 8*256 = 2048 sequences. That works out to 2048 / 465.45 ≈ 4.4 s per step. Am I correct? The script below computes the corresponding MFU, which comes out to 38.958427%. That's far too low!

# Nvidia Jax GPT5B
card_num = 256
gbs = 8 * card_num                 # global batch size in sequences
layers = 24
num_query = 32                     # query heads
num_heads = 32                     # attention heads (equal here, so the QKV factor is 3)
enc_seq_len = 2048
hs = 4096                          # hidden size
ffn_hs = 16384                     # feed-forward hidden size
vocab = 50304

sequences_per_sec = 465.45
seconds_per_step = gbs / sequences_per_sec   # 2048 / 465.45 ~= 4.4 s

# Model total parameters:
params_qkv_state = (1 + 2 * (num_query / num_heads)) * hs * hs   # Q, K, V projections
params_post_attention_linear = hs * hs                           # attention output projection
params_feed_forward_network = 2 * hs * ffn_hs                    # FFN up- and down-projections
params_vocabulary_embedding = hs * vocab                         # embedding / output head

# FPROP FLOPs per step (2 FLOPs per multiply-accumulate):
qkv_state = gbs * 2 * (1 + 2 * (num_query / num_heads)) * enc_seq_len * hs * hs
attention_matrix_computation = gbs * 2 * enc_seq_len * enc_seq_len * hs
attention_over_values = gbs * 2 * enc_seq_len * enc_seq_len * hs
post_attention_linear_projection = gbs * 2 * enc_seq_len * hs * hs
feed_forward_network = gbs * (2 * 2 * enc_seq_len * ffn_hs * hs)
vocabulary_embedding = gbs * 2 * enc_seq_len * hs * vocab

# BPROP costs roughly 2x FPROP, hence the factor of 3 below.

model_params = (params_qkv_state + params_post_attention_linear
                + params_feed_forward_network) * layers + params_vocabulary_embedding
model_float = 3 * ((qkv_state + attention_matrix_computation + attention_over_values
                    + post_attention_linear_projection + feed_forward_network) * layers
                   + vocabulary_embedding)
model_flops = model_float / seconds_per_step        # achieved FLOP/s
cluster_ideal_flops = 312 * (10**12) * card_num     # A100 bf16 peak: 312 TFLOP/s per GPU
MFU = model_flops / cluster_ideal_flops
print("Model parameters {:.6f}B MFU={:.6f}%".format(model_params / (10**9), MFU * 100))