jax-ml / ml_dtypes

A stand-alone implementation of several NumPy dtype extensions used in machine learning.
Apache License 2.0

Add sub-byte data types: float4_e2m1fn, float6_e2m3fn, float6_e3m2fn #181

Closed sergey-kozub closed 1 month ago

sergey-kozub commented 1 month ago

This PR adds support for the MX (microscaling) floating-point element types.

The F4e2m1, F6e2m3, and F6e3m2 types are proposed in the OpenCompute MX specification.

These types have the following notable features:

float4_e2m1fn
- Exponent bias: 1
- Maximum stored exponent value: 3 (binary 11)
- Maximum unbiased exponent value: 3 - 1 = 2
- Minimum stored exponent value: 1 (binary 01)
- Minimum unbiased exponent value: 1 − 1 = 0
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.00.0
- Max normal number: S.11.1 = ±2^(2) x (1 + 0.5) = ±6
- Min normal number: S.01.0 = ±2^(0) = ±1
- Min subnormal number: S.00.1 = ±2^(0) x 0.5 = ±0.5
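
To make the bit layout above concrete, here is a minimal pure-Python sketch that decodes a 4-bit pattern using only the rules listed (bias 1, no infinities, no NaNs). The function name `decode_float4_e2m1fn` is hypothetical and is not part of the ml_dtypes API.

```python
def decode_float4_e2m1fn(code: int) -> float:
    """Decode a 4-bit pattern laid out as sign, 2 exponent bits, 1 mantissa bit.

    Hypothetical helper for illustration; not the ml_dtypes implementation.
    """
    if not 0 <= code <= 0xF:
        raise ValueError("expected a 4-bit code")
    sign = -1.0 if code & 0b1000 else 1.0
    exponent = (code >> 1) & 0b11  # stored (biased) exponent
    mantissa = code & 0b1          # single fraction bit, worth 0.5
    bias = 1
    if exponent == 0:
        # Zero or subnormal: no implicit leading 1, exponent fixed at 1 - bias.
        return sign * 2.0 ** (1 - bias) * (mantissa * 0.5)
    # Normal number: implicit leading 1. Every stored exponent is a finite value
    # because the format has no infinities or NaNs.
    return sign * 2.0 ** (exponent - bias) * (1.0 + mantissa * 0.5)


# The eight non-negative codes, in increasing order.
positive_values = [decode_float4_e2m1fn(c) for c in range(8)]
```

The eight non-negative codes decode to 0, 0.5, 1, 1.5, 2, 3, 4 and 6, and the other eight codes are their negatives (including negative zero).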

float6_e2m3fn
- Exponent bias: 1
- Maximum stored exponent value: 3 (binary 11)
- Maximum unbiased exponent value: 3 - 1 = 2
- Minimum stored exponent value: 1 (binary 01)
- Minimum unbiased exponent value: 1 − 1 = 0
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.00.000
- Max normal number: S.11.111 = ±2^(2) x (1 + 0.875) = ±7.5
- Min normal number: S.01.000 = ±2^(0) = ±1
- Max subnormal number: S.00.111 = ±2^(0) x 0.875 = ±0.875
- Min subnormal number: S.00.001 = ±2^(0) x 0.125 = ±0.125

float6_e3m2fn
- Exponent bias: 3
- Maximum stored exponent value: 7 (binary 111)
- Maximum unbiased exponent value: 7 - 3 = 4
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs

Additional details:
- Zeros (+/-): S.000.00
- Max normal number: S.111.11 = ±2^(4) x (1 + 0.75) = ±28
- Min normal number: S.001.00 = ±2^(-2) = ±0.25
- Max subnormal number: S.000.11 = ±2^(-2) x 0.75 = ±0.1875
- Min subnormal number: S.000.01 = ±2^(-2) x 0.25 = ±0.0625
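
The limits listed for all three types follow from the exponent width, mantissa width, and bias alone. The sketch below recomputes them in closed form; `format_limits` is a hypothetical helper for illustration, not a function in ml_dtypes, and it assumes (as stated above) that the all-ones exponent encodes ordinary normal numbers because the formats have no infinities or NaNs.

```python
def format_limits(exp_bits: int, man_bits: int, bias: int) -> dict:
    """Closed-form limits for a small float format without inf/NaN encodings."""
    max_stored_exp = (1 << exp_bits) - 1  # all-ones exponent is still a normal value
    ulp = 2.0 ** -man_bits                # weight of the lowest mantissa bit
    min_normal = 2.0 ** (1 - bias)        # stored exponent 1, mantissa 0
    return {
        "max_normal": 2.0 ** (max_stored_exp - bias) * (2.0 - ulp),
        "min_normal": min_normal,
        "max_subnormal": min_normal * (1.0 - ulp),
        "min_subnormal": min_normal * ulp,
    }


# (exponent bits, mantissa bits, bias) as listed in the PR description.
limits = {
    "float4_e2m1fn": format_limits(2, 1, 1),
    "float6_e2m3fn": format_limits(2, 3, 1),
    "float6_e3m2fn": format_limits(3, 2, 3),
}
```

For float4_e2m1fn the maximum and minimum subnormal coincide at 0.5, since there is only one mantissa bit.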

Related PRs:

sergey-kozub commented 1 month ago

Note: a small unrelated change in "_finfo.py" removes unreadable boilerplate and replaces it with (faster) dict lookups for instantiating "finfo" objects.

hawkinsp commented 1 month ago

I'm trying to understand the relationship between these types and the MX types. From my quick read of the MX spec, all of the types it defines are block-scaled formats, which these types are not?

Can you say more about the relationship and the use case for these?

sergey-kozub commented 1 month ago

> I'm trying to understand the relationship between these types and the MX types. From my quick read of the MX spec, all of the types it defines are block-scaled formats, which these types are not?

An MXFP8 value is a pair of tensors: for example, the first could hold E5M2 elements and the second E8M0 scales, with 32x fewer elements than the first.
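
As a rough illustration of that pairing, the sketch below keeps the element tensor and the scale tensor side by side and expands them back to plain floats. It assumes one shared scale per block of 32 elements and an E8M0 exponent bias of 127, as defined in the MX spec; `BlockScaled` and `dequantize` are hypothetical names, and the elements are shown as already-decoded Python floats rather than raw E5M2 bits.

```python
from dataclasses import dataclass
from typing import List

BLOCK_SIZE = 32   # elements sharing one scale (the "32x" ratio)
E8M0_BIAS = 127   # assumption: E8M0 exponent bias as defined in the MX spec


@dataclass
class BlockScaled:
    """Hypothetical container: one element tensor plus one scale tensor."""
    elements: List[float]  # element values, already decoded from e.g. E5M2
    scales: List[int]      # one stored E8M0 exponent per block of 32 elements


def dequantize(t: BlockScaled) -> List[float]:
    """Multiply every element by the power-of-two scale of its block."""
    if len(t.elements) != len(t.scales) * BLOCK_SIZE:
        raise ValueError("need exactly one scale per block of 32 elements")
    out = []
    for i, x in enumerate(t.elements):
        scale = 2.0 ** (t.scales[i // BLOCK_SIZE] - E8M0_BIAS)
        out.append(x * scale)
    return out


# Two blocks: the first is scaled by 2^1, the second by 2^-1.
example = BlockScaled(elements=[1.5] * 32 + [-0.5] * 32, scales=[128, 126])
values = dequantize(example)
```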

Properly supporting such an MX type (where a single value has two different primitive types) is far too complicated, but we could use two separate values instead. That way, a dot op with scaled inputs (which is what we're actually interested in) could be represented as a custom call with four input tensors.
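
A minimal sketch of what a four-input scaled dot computes, written in pure Python for one-dimensional inputs. The name `scaled_dot` and its signature are hypothetical and do not correspond to an existing XLA custom call; the block size of 32 and the E8M0 bias of 127 are assumptions taken from the MX spec.

```python
from typing import Sequence

BLOCK_SIZE = 32   # assumption: MX block size
E8M0_BIAS = 127   # assumption: E8M0 exponent bias as defined in the MX spec


def scaled_dot(a_elems: Sequence[float], a_scales: Sequence[int],
               b_elems: Sequence[float], b_scales: Sequence[int]) -> float:
    """Dot product of two block-scaled vectors (hypothetical helper).

    Each operand is passed as two inputs: decoded element values and one
    stored E8M0 exponent per block of 32 elements.
    """
    n = len(a_elems)
    if len(b_elems) != n or n % BLOCK_SIZE != 0:
        raise ValueError("operands must have equal length, a multiple of 32")
    blocks = n // BLOCK_SIZE
    if len(a_scales) != blocks or len(b_scales) != blocks:
        raise ValueError("need exactly one scale per block for each operand")
    total = 0.0
    for blk in range(blocks):
        lo = blk * BLOCK_SIZE
        hi = lo + BLOCK_SIZE
        # Accumulate the unscaled products, then apply both scales once per block.
        partial = sum(x * y for x, y in zip(a_elems[lo:hi], b_elems[lo:hi]))
        total += partial * 2.0 ** (a_scales[blk] + b_scales[blk] - 2 * E8M0_BIAS)
    return total


# One block: 32 products of 1.0 * 0.5, scaled by 2^(128 - 127) * 2^(127 - 127).
result = scaled_dot([1.0] * 32, [128], [0.5] * 32, [127])
```

Because the scales are powers of two shared across a block, they can be applied to the block's partial sum rather than to each element.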

So, to implement MXFP8, we need an E8M0 primitive type in XLA (E5M2 and E4M3 already exist). For MXFP4, we need both E8M0 and E2M1. I'm adding the FP6 types (E2M3 and E3M2) for completeness; they are very similar and will unblock us in the future. All of these types are described in the MX spec: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf