patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

`is_array` returns `True` for instances of `np.generic` #507

Open colehaus opened 1 year ago

colehaus commented 1 year ago

I started looking at a PR for https://github.com/patrick-kidger/equinox/issues/486. While doing so, I noticed another potential issue: is_array's isinstance check includes np.generic. This does not seem to align with the intention for np.generic according to the documentation, where it's described as being for scalars.

If I remove np.generic, the only test that breaks is this one in test_callback. And the "breakage" there is actually revealing an error in the test IMO.

    def f(x, y):
        assert y is sentinel1
        return (x + 1).astype(np.float32), sentinel2

    out_struct = (jax.ShapeDtypeStruct((), jnp.float32), sentinel2)

    out = eqx.filter_pure_callback(
        f, jnp.array(1.0), sentinel1, result_shape_dtypes=out_struct
    )

The out_struct suggests that the return type will be a zero-dimensional array. That would be true if x were a JAX array, but numpy arrays and JAX arrays have a discrepancy here:

>>> jnp.array(1.0) + 1
Array(2., dtype=float32, weak_type=True)
>>> np.array(1.0) + 1
2.0

And the numpy behavior is the relevant one here since:

The input callback will be passed NumPy arrays in place of JAX arrays and should also return NumPy arrays.

according to the docs.

So f in the test is actually returning a float32 scalar instead of a zero-dimensional float32 array.
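The scalar-versus-array distinction can be checked with NumPy alone. The following is a minimal sketch; the variable names are illustrative and not taken from the test, and it assumes only that the callback receives a zero-dimensional NumPy array:

```python
import numpy as np

x = np.array(1.0)                  # zero-dimensional NumPy array, as the callback would receive
out = (x + 1).astype(np.float32)   # same expression as in f

# Arithmetic on a 0-d NumPy array yields a NumPy scalar, not an ndarray.
out_is_ndarray = isinstance(out, np.ndarray)   # False
out_is_generic = isinstance(out, np.generic)   # True

# Wrapping the sum in np.array gives a 0-d ndarray again.
fixed = np.array(x + 1).astype(np.float32)
fixed_is_ndarray = isinstance(fixed, np.ndarray)  # True
```

Both `out` and `fixed` have shape `()` and dtype `float32`; only their Python types differ.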

All of which is a somewhat long-winded way of saying that np.generic can be safely removed as far as the Equinox tests are concerned, with a slight tweak to this one test:

    -    return (x + 1).astype(np.float32), sentinel2
    +    return np.array((x + 1)).astype(np.float32), sentinel2

How this change would affect user code is, of course, less clear.

patrick-kidger commented 1 year ago

This is actually intentional behaviour. As you've noticed, NumPy uses a different type for scalar arrays, whilst JAX doesn't.

However, we often treat JAX arrays and NumPy arrays as being somehow "morally equivalent". This means we want the NumPy equivalent of a JAX scalar array to also be treated as an array. I believe in some cases JAX will actually substitute NumPy arrays for JAX arrays (I seem to recall this happening when using JAX_DEBUG_NANS=1), so something like eqx.is_array(jnp.array(1) + 1) might otherwise produce changeable behaviour.
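To make the trade-off concrete, here is a rough NumPy-only sketch. The two helper functions are hypothetical stand-ins written for illustration, not Equinox's actual implementation; they differ only in whether np.generic is accepted:

```python
import numpy as np

def is_array_strict(x):
    # Hypothetical check: only true ndarrays count as arrays.
    return isinstance(x, np.ndarray)

def is_array_lenient(x):
    # Hypothetical check: NumPy scalars (np.generic) count as arrays too.
    return isinstance(x, (np.ndarray, np.generic))

a = np.array(1)   # zero-dimensional array
b = a + 1         # NumPy scalar, because arithmetic on a 0-d array returns np.generic

strict = (is_array_strict(a), is_array_strict(b))     # (True, False)
lenient = (is_array_lenient(a), is_array_lenient(b))  # (True, True)
```

With the strict check, a value stops being an "array" as soon as it passes through arithmetic on the NumPy side; the lenient check gives the same answer for both values.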