jax-ml / ml_dtypes

A stand-alone implementation of several NumPy dtype extensions used in machine learning.
Apache License 2.0

np.floor_divide(?, ml_dtypes.???(0.0)) returns NaN but np.float16 returns Inf. #170

Closed: apivovarov closed this issue 2 months ago

apivovarov commented 2 months ago

If the second argument of np.floor_divide is 0.0, the result depends on the type: NumPy's native dtypes (e.g. np.float16) return Inf, while the ml_dtypes types return NaN.

Related test: testBinaryUfunc in ml_dtypes/tests/custom_float_test.py. The test fails if y contains any elements equal to 0.0.

Manual run to confirm the difference in behavior:

>>> np.floor_divide(ml_dtypes.float8_e4m3(1.0), ml_dtypes.float8_e4m3(0.0))
nan
>>> np.floor_divide(ml_dtypes.float8_e5m2(1.0), ml_dtypes.float8_e5m2(0.0))
nan
>>> np.floor_divide(ml_dtypes.bfloat16(1.0), ml_dtypes.bfloat16(0.0))
nan
>>> np.floor_divide(np.float16(1.0), np.float16(0.0))
np.float16(inf)

jakevdp commented 2 months ago

Hi - thanks for the report! It looks like that output comes from this line: https://github.com/jax-ml/ml_dtypes/blob/30f2497888b90db1c9f775ed1279315acdf44f30/ml_dtypes/_src/ufuncs.h#L171

We probably need to update this to condition the returned value on the value of a.

I think it should return {-inf, nan} if a < 0, {nan, nan} if a == 0, and {inf, nan} if a > 0.
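The proposed rule can be sketched in plain Python. This is only an illustration, not the actual C++ code in ufuncs.h: the function name divmod_sketch is hypothetical, each pair is read as (quotient, remainder), and, like the comment above, the sketch conditions only on a and ignores the sign of a zero divisor.

```python
import math


def divmod_sketch(a: float, b: float) -> tuple[float, float]:
    """Hypothetical sketch: return (quotient, remainder) for a floor-divmod by b."""
    if b == 0.0:
        # Zero divisor: the remainder is always NaN, and the quotient
        # follows the sign of a, as proposed in the comment above.
        if a > 0:
            return (math.inf, math.nan)
        if a < 0:
            return (-math.inf, math.nan)
        # a == 0 or a is NaN: neither comparison holds, so both are NaN.
        return (math.nan, math.nan)
    # Non-zero divisor: defer to Python's built-in float divmod.
    return divmod(a, b)


if __name__ == "__main__":
    for a in (1.0, -1.0, 0.0):
        print(a, divmod_sketch(a, 0.0))
```

With a positive numerator the quotient is inf, which matches the np.float16 output shown in the report.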

jakevdp commented 2 months ago

Thanks again for the report! We'll try to cut a new release soon with this fix. Let us know if you run into other issues.

apivovarov commented 2 months ago

Thank you! It would be great if we could also include float8_e3m4 in the new release.

PR-171, which adds float8_e3m4, is a clone of the previously merged PR-161, which added float8_e4m3.

One of the float8_e3m4 tests exposed this floor_divide issue when random data was generated for the float8_e3m4 type.