Pax is a JAX-based machine learning framework for training large-scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry-leading model FLOP utilization rates.
Apache License 2.0 · 456 stars · 68 forks
[Question] Very low MFU (30–35%) when training bf16 Llama2 and GPT models on a single SXM4 A100 machine. #65
I don't know what's happening. Are the computation precision and parameter precision not set correctly? DeepSpeed or Megatron can easily achieve 55% MFU on the same machine.
Here is my bash script:
According to https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/pax, NVIDIA trains a 5B GPT model with native BF16 on 256 A100 GPUs and reports 465.45 sequences/sec at a global batch size of 8 * 256 = 2048 sequences. That works out to 2048 / 465.45 ≈ 4.4 s per step. Am I correct? This script calculates the corresponding MFU as 38.958427%. That's too low!
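For reference, here is a minimal back-of-the-envelope sketch of how those numbers turn into an MFU figure. It assumes a sequence length of 2048, the common 6·N FLOPs-per-token training approximation (which ignores attention FLOPs), and a 312 TFLOP/s dense BF16 peak for an SXM4 A100; none of these constants come from the issue or the Pax configs themselves.

```python
# Hypothetical MFU sanity check for the NVIDIA 5B GPT run described above.
# Assumed (not from the issue): SEQ_LEN = 2048, the 6 * N FLOPs-per-token
# training approximation, and a 312 TFLOP/s BF16 peak per A100 SXM4.

PARAMS = 5e9                  # 5B-parameter GPT model
SEQ_LEN = 2048                # assumed sequence length
SEQ_PER_SEC = 465.45          # reported throughput across 256 A100s
NUM_GPUS = 256
PEAK_FLOPS = 312e12           # A100 SXM4 dense BF16 peak, FLOP/s

tokens_per_sec = SEQ_PER_SEC * SEQ_LEN
achieved_flops = 6 * PARAMS * tokens_per_sec      # total across all GPUs
mfu = achieved_flops / (NUM_GPUS * PEAK_FLOPS)

step_time = (8 * NUM_GPUS) / SEQ_PER_SEC          # global batch / throughput
print(f"step time ~ {step_time:.2f} s, MFU ~ {mfu:.1%}")
# -> step time ~ 4.40 s, MFU ~ 35.8%
```

Under those assumptions the run comes out around 35.8% MFU; the small gap to the quoted 38.958427% is plausibly the attention FLOPs that the 6·N approximation leaves out.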