google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

[NVIDIA] Rename fp8 custom dtype to `fp32_max_grad` #3984

Open kaixih opened 3 weeks ago

kaixih commented 3 weeks ago

This PR renames the original `fm32` dtype to `fp32_max_grad` to make clear that the dtype stores fp32 values and uses max for gradient accumulation.

cc. @nouiz
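To illustrate the naming, here is a minimal plain-Python sketch. It assumes that "using max for the gradient accumulation" means that when several gradient contributions for the same value are combined, the largest one is kept instead of their sum. The function names and sample values are hypothetical and are not taken from `flax/linen/fp8_ops.py`.

```python
from functools import reduce


def accumulate_sum(contributions):
    """Ordinary accumulation: contributions from every use are added."""
    return reduce(lambda a, b: a + b, contributions)


def accumulate_max(contributions):
    """Max-style accumulation (assumed semantics): keep the largest contribution."""
    return reduce(max, contributions)


# Hypothetical contributions for one fp32 value that is used three times.
contributions = [0.5, 2.0, 1.25]

summed = accumulate_sum(contributions)  # 0.5 + 2.0 + 1.25
maxed = accumulate_max(contributions)   # largest of the three

print("sum-accumulated:", summed)
print("max-accumulated:", maxed)
```

Under this assumption, the accumulated value stays within the range of the individual contributions, whereas summation grows with the number of uses.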

IvyZX commented 3 days ago

LGTM. Please rebase and clear the CI test checks, then I can merge it.

kaixih commented 3 days ago

@IvyZX Thanks for the comments. Just resolved the conflict. PTAL.

codecov-commenter commented 3 days ago

Codecov Report

Attention: Patch coverage is 0% with 6 lines in your changes missing coverage. Please review.

Project coverage is 0.00%. Comparing base (31adb00) to head (798cfe7). Report is 59 commits behind head on main.

| Files | Patch % | Lines |
|---|---|---|
| flax/linen/fp8_ops.py | 0.00% | 6 Missing :warning: |

Additional details and impacted files

```diff
@@          Coverage Diff           @@
##            main   #3984   +/-   ##
======================================
  Coverage   0.00%   0.00%
======================================
  Files        106     107     +1
  Lines      13582   13761   +179
======================================
- Misses     13582   13761   +179
```
