colehaus opened this issue 1 year ago.
This is actually intentional behaviour. As you've noticed, NumPy uses a different type for scalar arrays, whilst JAX doesn't.
However, we often treat JAX arrays and NumPy arrays as being somehow "morally equivalent". This means we want the NumPy equivalent of a JAX scalar array to also be treated as an array. I believe in some cases JAX will actually substitute JAX arrays for NumPy arrays (I seem to recall this happening when using `JAX_DEBUG_NANS=1`), so something like `eqx.is_array(jnp.array(1) + 1)` might otherwise produce inconsistent behaviour.
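A minimal NumPy-only sketch of the concern above: arithmetic on a zero-dimensional `ndarray` hands back a NumPy scalar (an `np.generic` instance), so whether that value counts as an array depends on whether the check includes `np.generic`. The two predicates below are hypothetical stand-ins for the two variants under discussion; they are not Equinox's actual `is_array` implementation.

```python
import numpy as np


# Hypothetical predicates mirroring the two variants discussed in this thread.
# They are illustrative only and are NOT Equinox's real `is_array`.
def is_array_with_generic(x) -> bool:
    # Treats NumPy scalars (np.generic) as arrays too.
    return isinstance(x, (np.ndarray, np.generic))


def is_array_without_generic(x) -> bool:
    # Only accepts genuine ndarray instances.
    return isinstance(x, np.ndarray)


zero_d = np.array(1)  # NumPy analogue of jnp.array(1): a zero-dimensional ndarray
result = zero_d + 1   # NumPy returns a scalar (np.generic), not a 0-d ndarray

for name, value in [("zero_d", zero_d), ("result", result)]:
    print(name, is_array_with_generic(value), is_array_without_generic(value))
# zero_d True True
# result True False
```

Under the first variant both values are classified the same way, which is the stability the reply argues for; under the second, the classification flips after a single `+ 1`.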
I started looking at a PR for https://github.com/patrick-kidger/equinox/issues/486. While doing so, I noticed another potential issue:
`is_array`'s `isinstance` check includes `np.generic`. This does not seem to align with the intention for `np.generic` according to the documentation, where it's described as the base class for NumPy scalar types. If I remove `np.generic`, the only test that breaks is this one in `test_callback`. And the "breakage" there actually reveals an error in the test, IMO. The `out_struct` suggests that the return type will be a zero-dimensional array. That would be true if `x` were a JAX array, but NumPy arrays and JAX arrays have a discrepancy here:
And the NumPy behavior is the relevant one here since:
according to the docs.
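A small sketch of the NumPy/JAX discrepancy described above, evaluating the same expression in both libraries. It assumes NumPy is installed and treats JAX as optional: the JAX half is simply skipped if `jax` cannot be imported.

```python
import numpy as np

try:
    import jax.numpy as jnp
except ImportError:  # JAX is optional here; the JAX half is skipped without it
    jnp = None

# NumPy: adding to a zero-dimensional array yields a NumPy scalar type.
np_result = np.array(1.0) + 1
print("NumPy:", type(np_result).__name__, isinstance(np_result, np.ndarray))
# NumPy: float64 False

if jnp is not None:
    # JAX: the same expression stays a zero-dimensional JAX array.
    jax_result = jnp.array(1.0) + 1
    print("JAX:", jax_result.shape, jax_result.ndim)
```

Both results are zero-dimensional in value; the difference is purely in the type NumPy chooses to return.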
So `f` in the test is actually returning a `float32` scalar instead of a zero-dimensional `float32` array. All of which is a somewhat long-winded way of saying that `np.generic` can be safely removed, as far as Equinox's tests are concerned, with a slight tweak to this one test:
`return (x + 1).astype(np.float32), sentinel2` → `return np.array(x + 1).astype(np.float32), sentinel2`
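A stripped-down, NumPy-only version of that tweak, with the callback machinery and `sentinel2` omitted. `f_original` and `f_tweaked` are hypothetical names standing in for the test's `f` before and after the change.

```python
import numpy as np


def f_original(x):
    # Before the tweak: for a 0-d ndarray `x`, `x + 1` is already a NumPy
    # scalar, and `.astype` on a scalar returns another scalar.
    return (x + 1).astype(np.float32)


def f_tweaked(x):
    # After the tweak: `np.array(...)` wraps the scalar back into a
    # zero-dimensional ndarray before casting.
    return np.array(x + 1).astype(np.float32)


x = np.array(1.0, dtype=np.float32)
out_original = f_original(x)
out_tweaked = f_tweaked(x)

print(type(out_original).__name__)                    # float32
print(type(out_tweaked).__name__, out_tweaked.shape)  # ndarray ()
```

Both results hold the same value with dtype `float32`; only the second is an `ndarray` of shape `()`, which is what a zero-dimensional `out_struct` describes.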
How this change would affect user code is, of course, less clear.