patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

An array-like defined with `eqx.Module` does not support `jnp.asarray` with the `__array__` method #898

Closed: gautierronan closed this issue 22 hours ago

gautierronan commented 1 day ago

I am running into the following bug (which may come from JAX) when using eqx.Module. I want to define an array-like object on which I can call jnp.asarray. I do this by defining the NumPy __array__ method, which, according to the documentation of jnp.asarray, should work.

When doing this with a regular class, it works as intended. However, when subclassing eqx.Module, I run into TypeError: Unexpected input type for array: <class '__main__.MyEqxArray'>.

MWE:

import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx

class MyArray:
    def __init__(self, x):
        self.x = x

    def __array__(self, dtype=None, copy=None) -> np.ndarray:
        return np.asarray(self.x, dtype=dtype)

class MyEqxArray(eqx.Module):
    x: jax.Array

    def __array__(self, dtype=None, copy=None) -> np.ndarray:
        return np.asarray(self.x, dtype=dtype)

x = MyArray(jnp.array([1, 2, 3]))
jnp.asarray(x)
# works as expected

y = MyEqxArray(jnp.array([1, 2, 3]))
jnp.asarray(y)
# TypeError: Unexpected input type for array: <class '__main__.MyEqxArray'>

Stack trace:

TypeError                                 Traceback (most recent call last)
<ipython-input-31-4d23b83b2dea> in ?()
     20 x = MyArray(jnp.array([1, 2, 3]))
     21 jnp.asarray(x)
     22 
     23 y = MyEqxArray(jnp.array([1, 2, 3]))
---> 24 jnp.asarray(y)

~/miniconda3/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py in ?(a, dtype, order, copy, device)
   4248                       "Consider using copy=None or copy=True instead.")
   4249   dtypes.check_user_dtype_supported(dtype, "asarray")
   4250   if dtype is not None:
   4251     dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)  # type: ignore[assignment]
-> 4252   return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)

~/miniconda3/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py in ?(object, dtype, copy, order, ndmin, device)
   4081     object = memoryview(object)
   4082     # TODO(jakevdp): update this once we support NumPy 2.0 semantics for the copy arg.
   4083     out = np.array(object) if copy else np.asarray(object)
   4084   else:
-> 4085     raise TypeError(f"Unexpected input type for array: {type(object)}")
   4086   out_array: Array = lax_internal._convert_element_type(
   4087       out, dtype, weak_type=weak_type, sharding=sharding)
   4088   if ndmin > ndim(out_array):

TypeError: Unexpected input type for array: <class '__main__.MyEqxArray'>
johannahaffner commented 1 day ago

Hi,

This is because JAX handles arrays a bit differently, I think. If I switch to the magic method __jax_array__, it works. That approach also avoids mixing numpy and jax.numpy, which don't always do the same thing under the hood, so their interactions can lead to confusing errors.

Here is the modified MWE, which now works:

import jax.numpy as jnp
import equinox as eqx
from jaxtyping import Array

class MyArray:
    def __init__(self, x):
        self.x = x

    def __jax_array__(self) -> Array:
        return self.x

class MyEqxArray(eqx.Module):
    x: Array

    def __jax_array__(self) -> Array:
        return self.x

x = MyArray(jnp.array([1, 2, 3]))
print(jnp.asarray(x))

y = MyEqxArray(jnp.array([1, 2, 3]))
print(jnp.asarray(y))

Depending on why you want to extract the arrays, the eqx.partition and eqx.combine functions let you split a model into its array and non-array components, as described here. I'm not quite sure what the best practice for magic methods is, though, since they are handled differently from regular methods, as Patrick stated here.

Note: I've also switched to jaxtyping for the type annotation, but that should have no bearing on avoiding the error.

johannahaffner commented 1 day ago

PS: Take my recommendation with a few grains of salt; the future of __jax_array__ appears uncertain. Way too fascinating for ordinary coffee breaks 😃

patrick-kidger commented 1 day ago

Looks like this is because Equinox modules are pytrees, so they get flattened here:

https://github.com/jax-ml/jax/blob/1bc9df429d87920bdbbf874e84a63fbe3111e27d/jax/_src/numpy/lax_numpy.py#L5643

As Johanna highlights, I think __array__ and __jax_array__ should probably not be considered fully supported in JAX.

gautierronan commented 22 hours ago

Right, thanks for both of your answers. This issue with __jax_array__ keeps coming up and is fairly annoying, I have to say, haha! Its lack of support is actually why I only implemented __array__ in my example above.

I guess I should raise a bug report with JAX, since the documentation of jnp.asarray explicitly states that a class with an __array__ method should be supported as input.