jax-ml / ml_dtypes

A stand-alone implementation of several NumPy dtype extensions used in machine learning.
Apache License 2.0

Support for non-saturating mode for fp8 #147

Open wonjeon opened 4 months ago

wonjeon commented 4 months ago

The OFP8 specification (https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1) defines two modes for handling out-of-range values during conversion: saturating and non-saturating. It looks like the current code runs in saturating mode by default:

>>> import math, numpy as np
>>> from ml_dtypes import float8_e4m3fn
>>> math.isnan(np.float32(449).astype(float8_e4m3fn))
False
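
To make the two modes concrete, here is a minimal pure-Python sketch of how a conversion to E4M3 could treat overflow in each mode. It is not the ml_dtypes implementation. The helper `quantize_e4m3` and its `saturate` flag are hypothetical names, and the sketch assumes round-to-nearest-even with the overflow check applied after rounding, plus my reading of the spec's handling of infinities.

```python
import math

E4M3_MAX = 448.0   # largest finite E4M3 value
MANT_BITS = 3      # explicit mantissa bits in E4M3
MIN_EXP = -6       # exponent of the smallest normal E4M3 value


def quantize_e4m3(x: float, saturate: bool) -> float:
    """Hypothetical helper: round x to the E4M3 grid and apply an overflow mode."""
    if math.isnan(x):
        return math.nan
    if math.isinf(x):
        # Assumption: infinity clamps when saturating, becomes NaN otherwise
        # (E4M3 has no infinity encoding).
        return math.copysign(E4M3_MAX, x) if saturate else math.nan
    if x == 0.0:
        return x
    # abs(x) = m * 2**e with 0.5 <= m < 1, so the binade exponent is e - 1.
    _, e = math.frexp(abs(x))
    exp = max(e - 1, MIN_EXP)  # clamping gives the fixed subnormal spacing
    step = math.ldexp(1.0, exp - MANT_BITS)  # spacing of representable values
    # Python's round() rounds ties to even, matching round-to-nearest-even.
    rounded = round(abs(x) / step) * step
    if rounded > E4M3_MAX:
        # Assumption: overflow is decided on the rounded value.
        rounded = E4M3_MAX if saturate else math.nan
    return math.copysign(rounded, x)


for value in (449.0, 1000.0):
    print(value, quantize_e4m3(value, saturate=True), quantize_e4m3(value, saturate=False))
```

Under these assumptions 449 rounds down to 448 in both modes, because the E4M3 spacing near the maximum is 32, so a value further out of range (such as 1000) may be a more telling probe of which mode is in effect.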

Is there any plan to support both modes? Thank you.