patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX.
Apache License 2.0

`eval_shape` incompatible with deserializing directly to host due to `_assert_same` #861

Open colehaus opened 1 month ago

colehaus commented 1 month ago

Suppose you have a large pytree where you want to ensure that the full tree is never on the JAX device (TPU/GPU). You might also want to minimize the allocation of transient arrays by using eval_shape. Your ser/de code would then look something like this:

from __future__ import annotations

from collections.abc import Callable
from pathlib import Path
from typing import TypeVar, TypeVarTuple

import equinox as eqx
import equinox._serialisation as eqx_ser
import jax
from numpy import ndarray

Shape = TypeVarTuple("Shape")
DType = TypeVar("DType")

def save(path: Path, array: ndarray[*Shape, DType]):
    with path.open("wb") as f:
        # We have to convert to JAX arrays because numpy doesn't handle bfloat16
        eqx.tree_serialise_leaves(f, array)

def load(path: Path, like_fn: Callable[[], ndarray[*Shape, DType]]) -> ndarray[*Shape, DType]:
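    with path.open("rb") as f:
        return eqx.tree_deserialise_leaves(
            f,
            # `like` is built with eval_shape so that no real array is allocated
            like=jax.eval_shape(like_fn),
            filter_spec=lambda f, x: jax.device_get(eqx.default_deserialise_filter_spec(f, x)),
        )
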
But that errors with:

File …/lib/python3.11/site-packages/equinox/, in _assert_same.<locals>._assert_same_impl(path, new, old)
    170     typeold = array_impl_type
    171 if typenew is not typeold:
--> 172     raise RuntimeError(
    173         f"Deserialised leaf at path '{jtu.keystr(path)}' has changed type from "
    174         f"{type(old)} in `like` to {type(new)} on disk."
    175     )
    176 if isinstance(new, (np.ndarray, jax.Array)):
    177     if new.shape != old.shape:

RuntimeError: Deserialised leaf at path '' has changed type from <class 'jax._src.api.ShapeDtypeStruct'> in `like` to <class 'numpy.ndarray'> on disk.

(The error message is slightly misleading in this case because the actual comparison we're performing and failing is between jaxlib.xla_extension.ArrayImpl (i.e. array_impl_type) and numpy.ndarray.)
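To make the mismatch concrete, here is a minimal standalone sketch of the type comparison visible in the traceback. `leaf_type_matches` is a hypothetical helper, not an Equinox function, and deriving `array_impl_type` from a freshly created array is an assumption about how the concrete JAX array class is obtained.

```python
import jax
import jax.numpy as jnp
import numpy as np


def leaf_type_matches(new, old, array_impl_type):
    """Simplified restatement of the type comparison from the traceback."""
    typenew = type(new)
    typeold = type(old)
    # A ShapeDtypeStruct in `like` is treated as if it were a JAX device array.
    if typeold is jax.ShapeDtypeStruct:
        typeold = array_impl_type
    return typenew is typeold


# Assumption: the concrete JAX array class is taken from a real array.
array_impl_type = type(jnp.zeros(()))

# eval_shape returns a ShapeDtypeStruct without allocating the array.
like = jax.eval_shape(lambda: jnp.zeros((2, 3)))

on_device = jnp.zeros((2, 3))        # a JAX array
on_host = jax.device_get(on_device)  # a numpy.ndarray on the host

device_ok = leaf_type_matches(on_device, like, array_impl_type)
host_ok = leaf_type_matches(on_host, like, array_impl_type)
print(type(like).__name__, type(on_host).__name__, device_ok, host_ok)
```

The device array passes because it has the swapped-in type, while the host array fails even though its shape and dtype agree with `like`.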

Note that users can circumvent the issue by monkey-patching out the check but that's pretty ugly:

def patched_assert_same(array_impl_type):  # type: ignore
    """Equinox generates a fixed `array_impl_type` that corresponds to a JAX array.
    Then `_assert_same_impl` swaps in this type for any `jax.ShapeDtypeStruct` in the `like`.
    It then compares the `like` types and types at the very end of deserialization.
    But this means we're forbidden from deserializing to the host with `eval_shape` and
    would instead have to deserialize the whole tree on device and transfer to the host.
    """

    def _assert_same_impl(path, new, old):  # type: ignore
        # No-op: skips the per-leaf checks, including the type comparison that raises.
        pass

    return _assert_same_impl

eqx_ser._assert_same = patched_assert_same  # type: ignore

patrick-kidger commented 1 month ago

Hmm, I'm a little mystified by this, because this was something I thought we added support for.

Indeed in the line just above your error, we have an explicit

if typeold is jax.ShapeDtypeStruct:
    typeold = array_impl_type

check to cast away ShapeDtypeStructs.

colehaus commented 1 month ago

Ah, yeah, I think the issue is because we're in a slightly unusual case where we actually want a numpy/host array returned while array_impl_type assumes we want a JAX/device array. If I remove the jax.device_get part on the custom filter_spec, then it works fine.

patrick-kidger commented 1 month ago

Right! So I think what you're trying to do here is reasonable. I'd be happy to take a PR adjusting this. (Maybe we just consider all kinds of JAX and NumPy array interchangeable?)
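One possible shape for such a change, sketched outside Equinox: treat NumPy arrays, JAX arrays, and `jax.ShapeDtypeStruct` as one interchangeable family and compare only shape and dtype within it. `assert_same_relaxed` is a hypothetical name and this is not the library's implementation, just an illustration of the idea.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np

# Hypothetical grouping: every array-like kind is considered interchangeable.
_ARRAY_LIKE = (np.ndarray, jax.Array, jax.ShapeDtypeStruct)


def assert_same_relaxed(path, new, old):
    """Check a deserialised leaf against `like`, ignoring host/device differences."""
    if isinstance(new, _ARRAY_LIKE) and isinstance(old, _ARRAY_LIKE):
        if new.shape != old.shape:
            raise RuntimeError(
                f"Leaf at path '{jtu.keystr(path)}' changed shape "
                f"from {old.shape} to {new.shape}."
            )
        if new.dtype != old.dtype:
            raise RuntimeError(
                f"Leaf at path '{jtu.keystr(path)}' changed dtype "
                f"from {old.dtype} to {new.dtype}."
            )
    elif type(new) is not type(old):
        raise RuntimeError(
            f"Leaf at path '{jtu.keystr(path)}' changed type "
            f"from {type(old)} to {type(new)}."
        )


like = jax.eval_shape(lambda: jnp.zeros((2, 3), dtype=jnp.float32))
host_leaf = np.zeros((2, 3), dtype=np.float32)

# A host array now passes against a ShapeDtypeStruct with matching shape and dtype.
assert_same_relaxed((), host_leaf, like)
```

Shape and dtype mismatches still raise, so the check keeps its value as a guard against loading the wrong file.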