gautierronan closed this issue 22 hours ago.
Hi,
this is because JAX handles arrays a bit differently, I think. If I switch to the magic method `__jax_array__`, it works. That approach also avoids mixing `numpy` and `jax.numpy`. The two don't always do the same thing under the hood, and their interactions could lead to confusing errors.
Here is the modified MWE, which now works:
```python
import jax.numpy as jnp
import equinox as eqx
from jaxtyping import Array


class MyArray:
    def __init__(self, x):
        self.x = x

    def __jax_array__(self) -> Array:
        return self.x


class MyEqxArray(eqx.Module):
    x: Array

    def __jax_array__(self) -> Array:
        return self.x


x = MyArray(jnp.array([1, 2, 3]))
print(jnp.asarray(x))

y = MyEqxArray(jnp.array([1, 2, 3]))
print(jnp.asarray(y))
```
Depending on what you want to extract the arrays for, the `eqx.partition` and `eqx.combine` functions let you split a model into its array and non-array components and put it back together, as described here.
I'm not quite sure what the best practice on magic methods is, though, since they are handled differently from regular methods, as Patrick stated here.
Note: I've also switched to jaxtyping for the type annotation, but that should have no bearing on the error being avoided.
PS: Take my recommendation with a few grains of salt; the future of `__jax_array__` appears uncertain. Way too fascinating for ordinary coffee breaks 😃
Looks like this is because Equinox modules are trees, so they get flattened here:
As Johanna highlights, I think `__array__` and `__jax_array__` should probably not be considered fully supported in JAX.
Right, thanks for both of your answers. This issue with `__jax_array__` keeps coming up and is fairly annoying, I have to say, haha! Its lack of support is actually why I only implemented `__array__` in my example above.
I guess I should raise a bug report in JAX, as the documentation of `jnp.asarray` does explicitly state that a class with the `__array__` method should be supported as input.
I am running into the following bug (which may come from JAX) when using `eqx.Module`. I want to define an array-like object on which I can call `jnp.asarray`. I achieve this by defining NumPy's `__array__` method, which, as per the documentation of `jnp.asarray`, should work. When doing this with a regular class, it works as intended. However, when subclassing `eqx.Module`, I run into `TypeError: Unexpected input type for array: <class '__main__.MyEqxArray'>`.
MWE:
Stack trace: