jax-ml / ml_dtypes

A stand-alone implementation of several NumPy dtype extensions used in machine learning.
Apache License 2.0

WIP: Add MX `E8M0` datatype to `ml_dtypes` #163

Closed: balancap closed this 1 month ago

balancap commented 1 month ago

Implements support for the E8M0 datatype in `ml_dtypes`, following the definition in the OCP MX format specification.
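Some background on the format: in the OCP MX specification, E8M0 is an 8-bit, exponent-only scale type. It has no sign bit and no mantissa, and its bias is 127, so code `E` encodes 2^(E - 127). `0xFF` is NaN, and the format has neither zero nor infinity. The plain-Python sketch below illustrates that encoding. It is not the `ml_dtypes` implementation from this PR. The helper names `e8m0_to_float` and `float_to_e8m0` are hypothetical. The float-to-E8M0 conversion is also an assumption, chosen for simplicity: it truncates to floor(log2 x), saturates out-of-range exponents, and maps non-positive or non-finite inputs to NaN.

```python
import math

E8M0_BIAS = 127   # exponent bias from the OCP MX E8M0 definition
E8M0_NAN = 0xFF   # the single NaN encoding; E8M0 has no zero or infinity


def e8m0_to_float(bits: int) -> float:
    """Decode an 8-bit E8M0 code: value = 2**(bits - 127), 0xFF is NaN."""
    if not 0 <= bits <= 0xFF:
        raise ValueError("E8M0 codes are 8-bit unsigned integers")
    if bits == E8M0_NAN:
        return math.nan
    return math.ldexp(1.0, bits - E8M0_BIAS)


def float_to_e8m0(x: float) -> int:
    """Hypothetical encoder: truncate x to the power of two at or below it.

    Assumptions (not taken from the PR): NaN, infinities, zero and negative
    inputs map to the NaN code, and exponents outside [-127, 127] saturate.
    """
    if not math.isfinite(x) or x <= 0.0:
        return E8M0_NAN
    _, e = math.frexp(x)   # x = m * 2**e with 0.5 <= m < 1
    exp = e - 1            # floor(log2(x)), computed exactly
    exp = max(-E8M0_BIAS, min(exp, 127))  # clamp to the range 2**-127 .. 2**127
    return exp + E8M0_BIAS
```

For example, `float_to_e8m0(3.0)` gives 128, which decodes back to 2.0: truncation keeps only the power of two below the input. Every finite code (0 through 254) round-trips exactly through the decoder and encoder.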

A few questions are still open: