jax-ml / ml_dtypes

A stand-alone implementation of several NumPy dtype extensions used in machine learning.
Apache License 2.0

Add float8_e4m3 #161

Closed apivovarov closed 3 weeks ago

apivovarov commented 1 month ago

This PR adds the f8E4M3 type.

The f8E4M3 type follows IEEE 754 conventions.

f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 - 7 = -6
- Precision (total number of significand (mantissa) bits,
    including the implicit leading integer bit): 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs
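
The parameters above can be cross-checked with a short Python sketch. It assumes the usual IEEE 754-style rules: bias = 2^(e-1) - 1, the all-ones exponent reserved for infinities and NaNs, and the all-zeros exponent reserved for zeros and subnormals. The helper name `ieee_like_params` is hypothetical and is not part of ml_dtypes.

```python
def ieee_like_params(exp_bits: int, man_bits: int) -> dict:
    """Derive IEEE 754-style format parameters from the field widths.

    Illustrative helper only (hypothetical name, not an ml_dtypes API).
    """
    bias = (1 << (exp_bits - 1)) - 1   # 2^(e-1) - 1
    max_stored = (1 << exp_bits) - 2   # all-ones exponent is reserved for inf/NaN
    min_stored = 1                     # all-zeros exponent is reserved for zero/subnormals
    return {
        "bias": bias,
        "max_stored_exp": max_stored,
        "max_unbiased_exp": max_stored - bias,
        "min_stored_exp": min_stored,
        "min_unbiased_exp": min_stored - bias,
        "precision": man_bits + 1,     # stored mantissa bits plus the implicit leading bit
    }


# f8E4M3: 4 exponent bits, 3 mantissa bits
params = ieee_like_params(4, 3)
print(params)
```

For 4 exponent bits and 3 mantissa bits this reproduces the bias, exponent range, and precision listed above.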

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
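
As a sanity check on the values listed above, here is a minimal pure-Python decoder for the S.EEEE.MMM bit layout. It is only a sketch of the encoding described in this PR, not the ml_dtypes implementation, and the function name `decode_f8e4m3` is hypothetical.

```python
import math


def decode_f8e4m3(byte: int) -> float:
    """Decode an 8-bit pattern laid out as S.EEEE.MMM with exponent bias 7.

    Illustrative helper only (hypothetical name, not an ml_dtypes API).
    """
    if not 0 <= byte <= 0xFF:
        raise ValueError("expected an 8-bit value")
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF   # 4 stored exponent bits
    man = byte & 0x7          # 3 stored mantissa bits
    if exp == 0xF:
        # All-ones exponent: infinity if mantissa is zero, NaN otherwise.
        return sign * math.inf if man == 0 else math.nan
    if exp == 0:
        # Zero or subnormal: 2^(-6) * (man / 8) = man * 2^(-9).
        return sign * man * 2.0 ** -9
    # Normal number: 2^(exp - 7) * (1 + man / 8).
    return sign * (1 + man / 8) * 2.0 ** (exp - 7)


max_normal = decode_f8e4m3(0b0_1110_111)
min_normal = decode_f8e4m3(0b0_0001_000)
max_subnormal = decode_f8e4m3(0b0_0000_111)
min_subnormal = decode_f8e4m3(0b0_0000_001)
print(max_normal, min_normal, max_subnormal, min_subnormal)
```

The four printed values correspond to the max/min normal and max/min subnormal entries in the list above.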

Related LLVM PRs:

RFCs:

Related ml_dtypes PRs:

C++ Testing

Tested as described in PR-123 Add CMakeLists for C++ tests

cmake --build build -- all test

100% tests passed, 0 tests failed out of 250

Total Test time (real) =  12.50 sec
apivovarov commented 1 month ago

Peter, Jake, whenever you have time, could you please review this pull request? f8E4M3 support has already been merged into LLVM APFloat and MLIR. I also need to add f8E4M3 to XLA. Internally, XLA uses tsl, which depends on ml_dtypes (tsl/platform/ml_dtypes.h).

@hawkinsp @jakevdp

hawkinsp commented 1 month ago

@apivovarov I'm on vacation this week, but I'll take a look next week. Sorry for the delay...

apivovarov commented 3 weeks ago

> @apivovarov I'm on vacation this week, but I'll take a look next week. Sorry for the delay...

Hi Peter, it seems this PR got overlooked. @hawkinsp

BTW, I've also attached a link to the StableHLO [RFC] Add f8E4M3 and f8E3M4 types support https://github.com/openxla/stablehlo/pull/2486

Not sure which change should be merged first: StableHLO or ml_dtypes.

hawkinsp commented 3 weeks ago

The linter is also sad, please fix.

hawkinsp commented 3 weeks ago

And it doesn't matter whether the stablehlo change or the ml_dtypes change is merged first, really, they aren't directly coupled.

hawkinsp commented 3 weeks ago

Linter is still sad.

apivovarov commented 3 weeks ago

> Linter is still sad.

Fixed. pre-commit run --all-files is all green now.