jax-ml / ml_dtypes

A stand-alone implementation of several NumPy dtype extensions used in machine learning.
Apache License 2.0

`np.testing.assert_array_equal()` not compatible with bfloat16 when the value is `nan` #206

Open ayaka14732 opened 1 week ago

ayaka14732 commented 1 week ago
import numpy as np
import jax.numpy as jnp

a = jnp.array([jnp.nan], dtype=jnp.float32)
np.testing.assert_array_equal(a, a)  # No error

a = jnp.array([jnp.nan], dtype=jnp.bfloat16)
np.testing.assert_array_equal(a, a)  # AssertionError

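The cause is that `np.testing.assert_array_equal()` does not recognise bfloat16 as a "number" type, so it skips the NaN-position handling it applies to built-in float types and falls through to a plain elementwise comparison, where `nan != nan`: https://github.com/numpy/numpy/blob/b3ddf2fd33232b8939f48c7c68a61c10257cd0c5/numpy/testing/_private/utils.py#L773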
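To make the check concrete, here is a minimal sketch of what I understand the linked line to do. It uses `ml_dtypes.bfloat16` directly rather than going through JAX. The character string and the helper name `looks_like_builtin_number` are my own paraphrase and are assumptions, not NumPy's public API; the exact string may differ between NumPy versions.

```python
import numpy as np
import ml_dtypes

# Assumption: a paraphrase of the hard-coded dtype-character check at the
# linked line. The exact string may differ between NumPy versions.
BUILTIN_NUMBER_CHARS = "?bhilqpBHILQPefdgFDG"


def looks_like_builtin_number(dtype):
    """Hypothetical stand-in for NumPy's internal 'is this a number?' test."""
    return np.dtype(dtype).char in BUILTIN_NUMBER_CHARS


# Print each dtype's type character and whether the check accepts it.
for dt in (np.float32, ml_dtypes.bfloat16):
    print(np.dtype(dt).name, repr(np.dtype(dt).char), looks_like_builtin_number(dt))
```

If the check rejects a dtype, the NaN-aware branch never runs for it, which matches the behaviour above.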

jakevdp commented 1 week ago

This is one of many places where NumPy hard-codes logic about its built-in set of dtypes. There's nothing that downstream dtype implementations like ml_dtypes can do to change this. We might think about raising this issue upstream in NumPy.
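In the meantime, one possible workaround on the test side is to upcast to float32 before comparing, since every bfloat16 value is exactly representable in float32. This is only a sketch: `assert_array_equal_upcast` is a hypothetical helper name, not part of NumPy or ml_dtypes, and it assumes the inputs are floating-point arrays narrower than float32.

```python
import numpy as np
import ml_dtypes


def assert_array_equal_upcast(actual, desired):
    """Hypothetical helper: compare two arrays after casting to float32.

    Assumes both inputs hold a floating-point dtype that fits exactly in
    float32 (e.g. bfloat16), so the cast does not change any value.
    """
    actual32 = np.asarray(actual).astype(np.float32)
    desired32 = np.asarray(desired).astype(np.float32)
    # float32 is a built-in dtype, so NumPy's usual NaN handling applies.
    np.testing.assert_array_equal(actual32, desired32)


a = np.array([np.nan, 1.5], dtype=ml_dtypes.bfloat16)
assert_array_equal_upcast(a, a)  # NaNs in matching positions compare equal
```

The trade-off is that the dtype itself is no longer part of the comparison, so a test that cares about the result being bfloat16 would need a separate dtype check.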