Note: a small unrelated change in "_finfo.py" removes unreadable boilerplate and replaces it with (faster) dict lookups for instantiating "finfo" objects.
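For illustration, here is a minimal sketch of what the dict-lookup approach to instantiating `finfo` objects can look like. The table fields and values below are made up for the example (using the two existing 8-bit types) and are not the actual `_finfo.py` code:

```python
import types

# Hypothetical parameter table: one entry per dtype, looked up at
# instantiation time instead of per-type construction boilerplate.
_FINFO_PARAMS = {
    "float8_e4m3fn": dict(bits=8, eps=2.0**-3, max=448.0, tiny=2.0**-6),
    "float8_e5m2": dict(bits=8, eps=2.0**-2, max=57344.0, tiny=2.0**-14),
}

def make_finfo(dtype_name):
    # A single code path: fetch the parameters and build the object.
    return types.SimpleNamespace(dtype=dtype_name, **_FINFO_PARAMS[dtype_name])

print(make_finfo("float8_e4m3fn").max)  # 448.0
```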
I'm trying to understand the relationship between these types and the MX types. From my quick read of the MX spec, all of the types it defines are block-scaled formats, which these types are not?
Can you say more about the relationship and the use case for these?
> I'm trying to understand the relationship between these types and the MX types. From my quick read of the MX spec, all of the types it defines are block-scaled formats, which these types are not?
The MXFP8 type is a pair of tensors: e.g., the first could have the E5M2 type and the second the E8M0 type, with 32x fewer elements (one scale per 32-element block).
Properly supporting such an MX type, where a single value combines two different primitive types, would be far too complicated. Instead, we can use two separate values. This way a dot op with scaled inputs (which is what we're actually interested in) can be represented as a custom call with four input tensors; see the sketch below.
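A schematic of that representation, using plain NumPy `float32` arrays to stand in for the E5M2/E8M0 storage types. The function names and the four-tensor signature are illustrative only, not XLA's actual custom-call interface:

```python
import numpy as np

BLOCK = 32  # one e8m0 scale per 32 values, per the MX spec

def dequantize(values, scales):
    # values: shape (n,), nominally e5m2; scales: shape (n // BLOCK,),
    # nominally e8m0 (powers of two). float32 stands in for both here.
    return values.astype(np.float32) * np.repeat(scales, BLOCK).astype(np.float32)

def scaled_dot(lhs_vals, lhs_scales, rhs_vals, rhs_scales):
    # The op we care about: conceptually a custom call taking four input
    # tensors (values + scales per operand) instead of two inputs with a
    # composite MX type.
    return np.dot(dequantize(lhs_vals, lhs_scales),
                  dequantize(rhs_vals, rhs_scales))

x = np.ones(64, np.float32)
print(scaled_dot(x, np.array([2.0, 4.0]), x, np.array([1.0, 1.0])))  # 192.0
```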
So, in order to implement MXFP8, we need an E8M0 primitive type in XLA (E5M2 and E4M3 already exist). For MXFP4, we need both E8M0 and E2M1. We're also adding the FP6 types (E2M3 and E3M2) for completeness; they are very similar and will unblock us in the future. All of these types are described in the MX spec: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
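For reference, a from-first-principles decoder for the two encodings above, following the bit layouts in the MX spec (1 sign + 2 exponent + 1 mantissa bits with bias 1 for E2M1; 8 exponent bits with bias 127 and a single NaN encoding for E8M0). This is illustration code, not code from the PR:

```python
def decode_e2m1(code):
    # 4 bits: 1 sign, 2 exponent (bias 1), 1 mantissa. No Inf or NaN
    # encodings; every one of the 16 codes is a finite value.
    sign = -1.0 if (code >> 3) & 1 else 1.0
    exp = (code >> 1) & 0b11
    man = code & 1
    if exp == 0:  # subnormal: no implicit leading 1
        return sign * man * 0.5
    return sign * (1.0 + man * 0.5) * 2.0 ** (exp - 1)

def decode_e8m0(code):
    # 8 bits, all exponent (bias 127): an unsigned power-of-two scale.
    # 0xFF is the single NaN encoding; there is no Inf and no sign bit.
    if code == 0xFF:
        return float("nan")
    return 2.0 ** (code - 127)

print([decode_e2m1(c) for c in range(8)])
# [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
```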
This PR adds support for MX (microscaling) floating point types.

The `F4e2m1`, `F6e2m3`, and `F6e3m2` types are proposed in the OpenCompute MX Specification. These types have the following notable features (illustrated in the sketch after this list):

- no `nan` encoding, only finite values are supported;
- no `inf` encoding, similar to the existing 8-bit types with the `fn` suffix;
- sub-byte widths (4 and 6 bits), similar to the `int2` and `int4` types.
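A small sketch demonstrating the first bullet for the two FP6 formats: decoding every code from the bit layouts in the MX spec (bias 1 for E2M3, bias 3 for E3M2) yields only finite values. Written for illustration, not taken from this PR:

```python
# Generic decoder for the FP6 formats, from the MX spec bit layouts.
def decode(code, exp_bits, man_bits, bias):
    total = 1 + exp_bits + man_bits
    sign = -1.0 if (code >> (total - 1)) & 1 else 1.0
    exp = (code >> man_bits) & ((1 << exp_bits) - 1)
    man = code & ((1 << man_bits) - 1)
    if exp == 0:  # subnormal: no implicit leading 1
        return sign * man * 2.0 ** (1 - bias - man_bits)
    return sign * (1.0 + man * 2.0 ** -man_bits) * 2.0 ** (exp - bias)

# All 64 codes of each format decode to finite values; no NaN or Inf.
e2m3 = [decode(c, exp_bits=2, man_bits=3, bias=1) for c in range(64)]
e3m2 = [decode(c, exp_bits=3, man_bits=2, bias=3) for c in range(64)]
assert max(e2m3) == 7.5 and max(e3m2) == 28.0
```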
Related PRs: