colehaus opened this issue 1 month ago.
Hmm, I'm a little mystified by this, because this was something I thought we added support for (https://github.com/patrick-kidger/equinox/issues/259, https://github.com/patrick-kidger/equinox/commit/c5fc44f4acff02f1b2c24f5f39f009c1b5ff5967).
Indeed, in the line just above your error, we have an explicit check to cast away `ShapeDtypeStruct`s:

```python
if typeold is jax.ShapeDtypeStruct:
    typeold = array_impl_type
```
Ah, yeah, I think the issue is that we're in a slightly unusual case where we actually want a NumPy/host array returned, whereas `array_impl_type` assumes we want a JAX/device array. If I remove the `jax.device_get` part of the custom `filter_spec`, then it works fine.
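The mismatch can be reproduced with plain JAX and NumPy, without Equinox. In the sketch below, `array_impl_type` is recomputed as `type(jnp.zeros(()))` (an assumption about how it is obtained) and the shapes are arbitrary. It mirrors the cast quoted above and shows that a leaf passed through `jax.device_get` comes back as a `numpy.ndarray`, which never matches the JAX array type that a `ShapeDtypeStruct` gets mapped to:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Assumption: array_impl_type is the concrete class of an ordinary JAX array,
# recovered here from a throwaway array.
array_impl_type = type(jnp.zeros(()))

# A `like` leaf built with eval_shape is a ShapeDtypeStruct, not a real array.
like = jax.eval_shape(lambda: jnp.zeros((2, 3), dtype=jnp.float32))

# A filter_spec that calls jax.device_get hands back a host-side NumPy array.
loaded = jax.device_get(jnp.zeros((2, 3), dtype=jnp.float32))

# Mirror the cast from the check quoted above.
typeold = type(like)
if typeold is jax.ShapeDtypeStruct:
    typeold = array_impl_type

print(type(like).__name__)      # ShapeDtypeStruct
print(typeold.__name__)         # ArrayImpl
print(type(loaded).__name__)    # ndarray
print(type(loaded) is typeold)  # False: host array vs. device array type
```

Shape and dtype agree; only the concrete array class differs, which is exactly what the type check rejects.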
Right! So I think what you're trying to do here is reasonable. I'd be happy to take a PR adjusting this. (Maybe we just consider all kinds of JAX and NumPy arrays interchangeable?)
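One possible reading of the "interchangeable" suggestion is sketched below. This is a hypothetical helper, not Equinox's actual code (`assert_same_relaxed` and `_ARRAY_TYPES` are illustrative names): any `jax.Array`, `numpy.ndarray` or `jax.ShapeDtypeStruct` counts as an array, and for two arrays only shape and dtype are compared, so host-vs-device placement is left entirely to the `filter_spec`. Non-array leaves must still keep exactly the same type.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical: every leaf type treated as "an array" for comparison purposes.
_ARRAY_TYPES = (jax.Array, np.ndarray, jax.ShapeDtypeStruct)


def assert_same_relaxed(new, old):
    """Hypothetical relaxed leaf check for deserialisation.

    JAX arrays, NumPy arrays and ShapeDtypeStructs are interchangeable;
    only their shape and dtype must agree. Other leaves must keep their type.
    """
    if isinstance(new, _ARRAY_TYPES) and isinstance(old, _ARRAY_TYPES):
        if tuple(new.shape) != tuple(old.shape):
            raise RuntimeError(
                f"Deserialised leaf has changed shape from {old.shape} in "
                f"`like`, to {new.shape} on disk."
            )
        if new.dtype != old.dtype:
            raise RuntimeError(
                f"Deserialised leaf has changed dtype from {old.dtype} in "
                f"`like`, to {new.dtype} on disk."
            )
        return
    if type(new) is not type(old):
        raise RuntimeError(
            f"Deserialised leaf has changed type from {type(old)} in `like`, "
            f"to {type(new)} on disk."
        )


# Example: a `like` leaf from eval_shape accepts both host and device arrays.
like = jax.eval_shape(lambda: jnp.zeros((2, 3), dtype=jnp.float32))
assert_same_relaxed(np.zeros((2, 3), dtype=np.float32), like)   # host array
assert_same_relaxed(jnp.zeros((2, 3), dtype=jnp.float32), like)  # device array
```

Comparing shape and dtype instead of the concrete class lets an `eval_shape`-built `like` tree accept leaves loaded onto either the host or the device; the trade-off is that a deliberate switch between NumPy and JAX arrays is no longer flagged.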
Suppose you have a large pytree where you want to ensure that the full tree is never on the JAX device (TPU/GPU). You might also want to minimize the allocation of transient arrays by using `eval_shape`. Your ser/de code would then look something like this:

But that errors with:
(The error message is slightly misleading in this case, because the comparison we're actually performing, and failing, is between `jaxlib.xla_extension.ArrayImpl` (i.e. `array_impl_type`) and `numpy.ndarray`.)

Note that users can circumvent the issue by monkey-patching out the check, but that's pretty ugly: