jax-ml / ml_dtypes

A stand-alone implementation of several NumPy dtype extensions used in machine learning.
Apache License 2.0

f16 sNaN is cast to f8e5m2 qNaN in Half_To_Float8E5m2 test #162

Open apivovarov opened 1 month ago

apivovarov commented 1 month ago

qNaN vs sNaN

f16 qNaN 0.11111.1000000000
f16 sNaN 0.11111.0100000000 (one of the examples)

The problematic test is Half_To_Float8E5m2:

  Eigen::half nan =
      Eigen::numext::bit_cast<Eigen::half>(static_cast<uint16_t>(0x7C01));
  EXPECT_EQ(static_cast<float8_e5m2>(nan).rep(), 0x7E);

The input is 0x7C01, which is the sNaN 0.11111.0000000001; the expected result is 0x7E, which is the qNaN 0.11111.10. This is incorrect: a static_cast should preserve whether the NaN is quiet or signaling.

Instead, the expected result should be the sNaN 0.11111.01, i.e. 0x7D.

apivovarov commented 1 month ago

@cantonios @jakevdp @hawkinsp what do you think?

cantonios commented 1 month ago

> @cantonios @jakevdp @hawkinsp what do you think?

Yeah, it looks like we always quiet the NaN on conversions. We should be satisfying IEEE 754 §6.2.3, preserving the signaling bit and payload as much as possible:

> Conversion of a quiet NaN from a narrower format to a wider format in the same radix, and then back to the same narrower format, should not change the quiet NaN payload in any way except to make it canonical. Conversion of a quiet NaN to a floating-point format of the same or a different radix that does not allow the payload to be preserved shall return a quiet NaN that should provide some language-defined diagnostic information.

I'd be happy to look at a PR to fix this here and here.