Pax is a JAX-based machine learning framework for training large-scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry-leading model FLOP utilization rates.
Support FP8 params updating for NVIDIA Hopper GPUs #44
In this PR, we support configuring FP8 GEMMs and updating the FP8 params. This PR depends on the praxis PR, which introduces the FP8 GEMMs to the praxis layers. Users can simply set USE_FP8 = True to turn on the FP8 layers (i.e., the QKV projection, attention output projection, and feedforward layers) in the transformer layer.
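For illustration, here is a minimal JAX sketch of the per-tensor, delayed-scaling FP8 GEMM pattern that such layers typically implement. The names (`fp8_gemm`, `next_scale`, `E4M3_MAX`) and the structure are hypothetical and are not the praxis API; the real layers manage scales and amax histories as non-trainable variables updated during training.

```python
import jax.numpy as jnp

E4M3_MAX = 448.0  # largest magnitude representable in float8_e4m3fn

def fp8_gemm(x, w, x_scale, w_scale):
    """Hypothetical per-tensor delayed-scaling FP8 matmul sketch."""
    # Scale and clip each operand into FP8 range, then round-trip
    # through float8_e4m3fn to emulate low-precision GEMM inputs.
    x_q = jnp.clip(x / x_scale, -E4M3_MAX, E4M3_MAX).astype(jnp.float8_e4m3fn)
    w_q = jnp.clip(w / w_scale, -E4M3_MAX, E4M3_MAX).astype(jnp.float8_e4m3fn)
    # Accumulate in a wider type; on Hopper GPUs, XLA can lower FP8
    # operands like these to native FP8 Tensor Core GEMMs.
    out = jnp.dot(x_q.astype(jnp.bfloat16), w_q.astype(jnp.bfloat16))
    # Dequantize the result back to the original range.
    return out * (x_scale * w_scale)

def next_scale(amax_history):
    # "FP8 params" such as scales are updated from a running amax
    # history rather than learned by the optimizer.
    return jnp.max(amax_history) / E4M3_MAX
```

A usage sketch under the same assumptions: `fp8_gemm(x, w, x_scale=jnp.float32(1.0), w_scale=jnp.float32(1.0))` performs one scaled FP8 matmul, and `next_scale` shows why the scales need a dedicated update path alongside the regular weight update, which is what this PR adds.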
cc @pjannaty @reedwm