google / paxml

Pax is a Jax-based machine learning framework for training large scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry leading model flop utilization rates.
Apache License 2.0

[NVIDIA] Support new config option `USE_FP8` #49

Closed kaixih closed 1 year ago

kaixih commented 1 year ago

This PR introduces a new performance-related config option, USE_FP8, which calls the function tr_set_fp8_quantization provided in praxis to configure the recommended layers inside the transformer to use the FP8 GEMM.
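The gating pattern is roughly: a boolean experiment flag that, when set, rewrites the transformer config via the praxis setter. A minimal, self-contained sketch of that pattern is below; everything except the names USE_FP8 and tr_set_fp8_quantization (here a stand-in function) is hypothetical and not the actual praxis/paxml code:

```python
# Hypothetical sketch of a boolean flag gating an FP8 rewrite of a
# transformer config. `TransformerConfig`, `set_fp8_quantization`, and
# `build_config` are illustrative stand-ins, not the real praxis/paxml APIs.
import dataclasses


@dataclasses.dataclass
class TransformerConfig:
    use_fp8: bool = False           # mirrors the new USE_FP8 option
    gemm_dtype: str = "bfloat16"    # dtype used by the gated GEMM layers


def set_fp8_quantization(cfg: TransformerConfig) -> TransformerConfig:
    # Stand-in for praxis's tr_set_fp8_quantization: switch the
    # recommended layers to an FP8 GEMM path.
    return dataclasses.replace(cfg, gemm_dtype="float8_e4m3")


def build_config(use_fp8: bool) -> TransformerConfig:
    # The flag only triggers the rewrite; all other settings are untouched.
    cfg = TransformerConfig(use_fp8=use_fp8)
    if cfg.use_fp8:
        cfg = set_fp8_quantization(cfg)
    return cfg


print(build_config(use_fp8=True).gemm_dtype)   # float8_e4m3
print(build_config(use_fp8=False).gemm_dtype)  # bfloat16
```

The design keeps the FP8 decision in one place: experiment configs flip a single flag, and the setter owns the knowledge of which layers should be quantized.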

There are four related PRs, which should be reviewed in this order:

1. https://github.com/google/praxis/pull/29
2. https://github.com/google/paxml/pull/48
3. https://github.com/google/praxis/pull/28
4. https://github.com/google/paxml/pull/49 (this PR)

cc. @pjannaty @reedwm @nluehr @lukaszlew

kaixih commented 1 year ago

@zhangqiaorjc I've noticed that this pull request has been in a "pull ready" status for a couple of days. Is there a specific action needed, like clicking a button to merge the PR, or will the process be automated by a robot?