ayaka14732 opened 1 week ago
This is one of many places where NumPy hard-codes logic about its built-in set of dtypes. Downstream dtype implementations such as ml_dtypes can do nothing to change this, so we might consider raising the issue upstream with NumPy.
The cause is that `np.testing.assert_array_equal()` does not recognise bfloat16 as a "number" type: https://github.com/numpy/numpy/blob/b3ddf2fd33232b8939f48c7c68a61c10257cd0c5/numpy/testing/_private/utils.py#L773
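A minimal sketch of the kind of check involved. It assumes, from a reading of that NumPy revision, that the test compares the dtype's single-character type code against a fixed string of built-in codes, and that NaN handling is only applied when both inputs pass it. `BUILTIN_NUMBER_CHARS`, `is_builtin_number` and `nan_aware_equal` are hypothetical names, and the comparison is a heavy simplification of what NumPy actually does:

```python
import numpy as np

# Assumption: the type codes NumPy's testing helper treats as "numbers"
# (bool, signed/unsigned ints, float16/32/64/longdouble, complex types).
BUILTIN_NUMBER_CHARS = "?bhilqpBHILQPefdgFDG"


def is_builtin_number(arr: np.ndarray) -> bool:
    """Hypothetical helper: True only for NumPy's built-in numeric dtypes."""
    return arr.dtype.char in BUILTIN_NUMBER_CHARS


def nan_aware_equal(x, y) -> bool:
    """Hypothetical, simplified stand-in for the comparison step.

    NaN positions are reconciled only when both inputs pass the dtype check;
    any other dtype falls through to plain ==, where NaN != NaN.
    """
    x, y = np.asarray(x), np.asarray(y)
    if is_builtin_number(x) and is_builtin_number(y):
        both_nan = np.isnan(x) & np.isnan(y)
        return bool(np.all((x == y) | both_nan))
    return bool(np.all(x == y))


a = np.array([1.0, np.nan], dtype=np.float32)
print(is_builtin_number(a))         # True: 'f' is in the hard-coded string
print(nan_aware_equal(a, a.copy())) # True: matching NaNs are treated as equal
```

Because the test keys on a fixed list of type codes, a dtype registered by a third-party package such as ml_dtypes cannot pass it; per this issue, a bfloat16 array fails it even though it holds floating-point values, so it misses whatever handling is gated on that test.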