google / paxml

Pax is a JAX-based machine learning framework for training large-scale models. Pax allows for advanced, fully configurable experimentation and parallelization, and has demonstrated industry-leading model FLOP utilization rates.
Apache License 2.0

Support FP8 params updating for NVIDIA Hopper GPUs #44

Closed: kaixih closed this PR 1 year ago

kaixih commented 1 year ago

This PR adds support for configuring FP8 GEMMs and updating the FP8 params. It depends on the praxis PR that introduces FP8 GEMMs to the praxis layers. Users can simply set USE_FP8 = True to turn on the FP8 layers in the transformer layer (i.e., the QKV projection, attention output projection, and feedforward layers).
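For readers unfamiliar with what "updating the FP8 params" involves, the general delayed-scaling recipe can be sketched in NumPy. This is a simulation under stated assumptions, not the code from this PR or from praxis: it assumes per-tensor scaling, the E4M3 format (largest finite value 448) for both GEMM inputs, and that the FP8 params being updated are a scale plus a rolling amax history per tensor. The helper names (`round_to_e4m3`, `update_amax_history`, `compute_scale`, `fp8_matmul`) are hypothetical, and the FP8 cast is emulated by rounding to the E4M3 grid rather than by using a hardware dtype such as `jnp.float8_e4m3fn`.

```python
import numpy as np

# Hypothetical sketch: NumPy emulation of FP8 (E4M3) GEMMs with delayed
# per-tensor scaling. Not the praxis/paxml implementation.

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value


def round_to_e4m3(x):
    """Emulate a cast to FP8 E4M3: saturate at +/-448, keep 3 mantissa bits."""
    x = np.clip(x, -E4M3_MAX, E4M3_MAX)
    # Binade exponent of each value; values below the smallest normal (2**-6)
    # share the subnormal step size of 2**-9.
    exp = np.floor(np.log2(np.maximum(np.abs(x), 2.0**-6)))
    step = 2.0 ** (exp - 3)  # 3 mantissa bits -> 8 steps per binade
    return np.round(x / step) * step  # round to nearest, ties to even


def update_amax_history(x, amax_history):
    """Drop the oldest entry and record the current amax at index 0."""
    new_history = np.roll(amax_history, shift=1)
    new_history[0] = np.max(np.abs(x))
    return new_history


def compute_scale(amax_history, prev_scale, fp8_max=E4M3_MAX, margin=0):
    """Pick a scale so the largest recent |x| maps to fp8_max / 2**margin."""
    amax = np.max(amax_history)
    if not np.isfinite(amax) or amax <= 0.0:
        return prev_scale  # keep the old scale if there is no usable amax
    return (amax / fp8_max) * 2.0**margin


def fp8_matmul(a, b, a_scale, b_scale):
    """Quantize both inputs with their scales, multiply, then rescale."""
    a_q = round_to_e4m3(a / a_scale)
    b_q = round_to_e4m3(b / b_scale)
    # FP8 GEMMs accumulate in higher precision; float64 stands in here.
    return (a_q @ b_q) * (a_scale * b_scale)


rng = np.random.default_rng(0)
a = rng.normal(size=(16, 32))
b = rng.normal(size=(32, 8))
hist_a, hist_b = np.zeros(4), np.zeros(4)
scale_a = scale_b = 1.0

for _ in range(3):
    # Delayed scaling: this step uses scales computed from earlier steps...
    out = fp8_matmul(a, b, scale_a, scale_b)
    # ...and only then records its own amax for the next step's scales.
    hist_a = update_amax_history(a, hist_a)
    hist_b = update_amax_history(b, hist_b)
    scale_a = compute_scale(hist_a, scale_a)
    scale_b = compute_scale(hist_b, scale_b)

ref = a @ b
rel_err = np.max(np.abs(out - ref)) / np.max(np.abs(ref))
print(f"max relative error vs. float64 GEMM: {rel_err:.4f}")
```

Each step quantizes with scales derived from earlier steps' amax values and only afterwards records its own amax. That is why, in this recipe, the scales and amax histories are stateful parameters that have to be updated alongside the weights rather than recomputed inside each GEMM call.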

cc @pjannaty @reedwm

pjannaty commented 1 year ago

cc @zhangqiaorjc

pjannaty commented 1 year ago

cc @lew

kaixih commented 1 year ago

Closing, as this has been resubmitted in other PRs.