jax-ml / ml_dtypes

A stand-alone implementation of several NumPy dtype extensions used in machine learning.
Apache License 2.0

Add float8_e3m4 #171

Closed: apivovarov closed this pull request 2 weeks ago.

apivovarov commented 3 weeks ago

This PR adds the f8E3M4 type.

The f8E3M4 type follows IEEE 754 conventions.

f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision (total number of significand bits, including the implicit leading integer bit): 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has positive and negative zero
- Has positive and negative infinity
- Has NaNs
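
As a rough illustration of the layout above, the following pure-Python sketch decodes an 8-bit pattern under the stated assumptions: 1 sign bit, 3 exponent bits, 4 mantissa bits, bias 3, and the all-ones exponent reserved for infinity and NaN. `decode_f8e3m4` is a hypothetical helper written for illustration only; it is not part of ml_dtypes and not the C++ implementation in this PR.

```python
import math

# Assumed field layout (from the description above): S.EEE.MMMM
EXP_BITS = 3
MAN_BITS = 4
BIAS = 2 ** (EXP_BITS - 1) - 1  # = 3
EXP_MASK = (1 << EXP_BITS) - 1  # 0b111
MAN_MASK = (1 << MAN_BITS) - 1  # 0b1111
MAN_SCALE = 1 << MAN_BITS       # 16


def decode_f8e3m4(bits: int) -> float:
    """Hypothetical helper: decode an 8-bit f8E3M4 pattern (0..255) to a float."""
    if not 0 <= bits <= 0xFF:
        raise ValueError("expected an 8-bit pattern")
    sign = -1.0 if bits & 0x80 else 1.0
    exp = (bits >> MAN_BITS) & EXP_MASK
    man = bits & MAN_MASK
    if exp == EXP_MASK:
        # All-ones exponent: infinity if the mantissa is zero, otherwise NaN.
        return sign * math.inf if man == 0 else math.nan
    if exp == 0:
        # Zero or subnormal: no implicit leading 1, exponent fixed at 1 - BIAS.
        return sign * (man / MAN_SCALE) * 2.0 ** (1 - BIAS)
    # Normal number: implicit leading 1.
    return sign * (1 + man / MAN_SCALE) * 2.0 ** (exp - BIAS)
```

For example, `decode_f8e3m4(0b0_110_1111)` returns `15.5` (the max normal number), and `decode_f8e3m4(0b1_000_0000)` returns `-0.0`.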

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6)
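
The values above follow from just two parameters, the number of exponent bits and the number of mantissa bits. A minimal sketch, assuming IEEE 754-style rules (bias 2^(e-1) - 1, all-ones exponent reserved for Inf/NaN, subnormals when the exponent field is 0); `ieee_like_limits` is a hypothetical name, and exact `Fraction` arithmetic is used so results can be compared with the list without rounding.

```python
from fractions import Fraction


def ieee_like_limits(exp_bits: int, man_bits: int) -> dict:
    """Hypothetical helper: key constants of an IEEE 754-style binary format.

    Assumes the all-ones exponent is reserved for Inf/NaN and that an
    exponent field of 0 encodes zeros and subnormals.
    """
    bias = 2 ** (exp_bits - 1) - 1
    max_stored = 2 ** exp_bits - 2  # all-ones (2**exp_bits - 1) is reserved
    min_stored = 1                  # 0 is reserved for zeros/subnormals
    max_exp = max_stored - bias
    min_exp = min_stored - bias
    ulp = Fraction(1, 2 ** man_bits)  # spacing of significands in [1, 2)
    return {
        "bias": bias,
        "max_exp": max_exp,
        "min_exp": min_exp,
        "precision": man_bits + 1,  # includes the implicit leading bit
        "max_normal": Fraction(2) ** max_exp * (2 - ulp),
        "min_normal": Fraction(2) ** min_exp,
        "max_subnormal": Fraction(2) ** min_exp * (1 - ulp),
        "min_subnormal": Fraction(2) ** min_exp * ulp,
    }
```

For f8E3M4, `ieee_like_limits(3, 4)["max_normal"]` is `Fraction(31, 2)`, i.e. 15.5. As a sanity check, the same formulas with 5 exponent bits and 10 mantissa bits (IEEE binary16) give the familiar max normal of 65504.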

Related LLVM PRs:

RFCs:

Related ml_dtypes PRs:

C++ Testing

Tested as described in PR #123 (Add CMakeLists for C++ tests):

cmake --build build -- all test

100% tests passed, 0 tests failed out of 307

Total Test time (real) =  14.88 sec

pre-commit code validation

pre-commit run --all-files
check python ast.........................................................Passed
check for merge conflicts................................................Passed
check toml...............................................................Passed
check yaml...............................................................Passed
fix end of files.........................................................Passed
trim trailing whitespace.................................................Passed
debug statements (python)................................................Passed
pyink....................................................................Passed
ruff.....................................................................Passed

apivovarov commented 3 weeks ago

Rebased on top of "BUG: fix float divmod with zero denominator".

apivovarov commented 2 weeks ago

Hi Peter, could you please take a look at this PR? It adds support for float8_e3m4 and is a clone of the previously merged PR #161, which added support for float8_e4m3.

@hawkinsp