patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Add some kind of de/serialisation? #46

Closed: patrick-kidger closed this issue 2 years ago

patrick-kidger commented 2 years ago

Equinox models are just PyTrees, so they should be very easy to serialise/deserialise: save the PyTree to disk in whatever way is desired. It might still be worth adding some library functions for this, just for convenience. Perhaps these could also check the device of JAX arrays, etc.?

This should respect the get_state/set_state stuff that's being put together.

In addition, there should be a version of get_state which inlines its state in the call graph, for faster inference.
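At its simplest, "just save the PyTree to disk" can be a flatten plus `np.savez`, as in the sketch below. The helper names `save_pytree` and `load_pytree` are hypothetical (not an Equinox API), and the sketch assumes every leaf is an array. Non-array leaves such as activation functions are exactly where this breaks down, as the rest of the thread discusses. Note that `np.asarray` also copies each JAX array back to host memory, whichever device it lives on.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np


def save_pytree(path, pytree):
    # Assumes every leaf is an array. np.asarray copies each leaf to host memory.
    leaves = jax.tree_util.tree_leaves(pytree)
    np.savez(path, *[np.asarray(x) for x in leaves])  # stored as arr_0, arr_1, ...


def load_pytree(path, like):
    # The tree structure is not stored on disk; it is recovered from the template `like`.
    treedef = jax.tree_util.tree_structure(like)
    with np.load(path) as data:
        leaves = [jnp.asarray(data[f"arr_{i}"]) for i in range(len(data.files))]
    return jax.tree_util.tree_unflatten(treedef, leaves)


params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "params.npz")
    save_pytree(path, params)
    restored = load_pytree(path, params)
```

Because only the leaves are written, the caller must already be able to build a PyTree of the right structure to load into.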

jaschau commented 2 years ago

I would like to pick up on a comment you made in https://www.reddit.com/r/MachineLearning/comments/u34oh2/d_what_jax_nn_library_to_use/i4umg44/?context=3.

The reason I'm dragging my heels on this is that I'm not yet completely sold on whether to use the code above, or to instead try and pickle the entire PyTree in one go. The former requires you to be able to produce the PyTree structure yourself (and by default doesn't save e.g. NumPy arrays); the latter has annoying compatibility concerns. (+I think Flax has a couple of ways of doing de/serialisation that might be worth using as inspiration.) So I want to be sure we get this right.

I agree that it would be attractive to have an option to serialize complete PyTrees, but I also see some challenges. Just to give an example, if you have an Equinox module like

from typing import Any

import equinox as eqx
import jax.nn as jnn

class MyModule(eqx.Module):
    activation: Any

module = MyModule(activation=jnn.tanh)

then pickling it will fail, because jnn.tanh seems to point to a lambda, and pickle doesn't support pickling lambdas (https://stackoverflow.com/questions/25348532/can-python-pickle-lambda-functions).
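The reason lambdas are a problem is that pickle stores functions by reference (module plus qualified name) rather than by value, so a named, importable function round-trips while an anonymous one cannot be looked up again. A minimal stdlib-only sketch, independent of JAX, with `math.tanh` standing in for a named activation and `can_pickle` a hypothetical helper:

```python
import math
import pickle


def can_pickle(obj) -> bool:
    """Return True if pickle can serialise `obj`."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError):
        # A lambda's qualified name is "<lambda>", which pickle cannot resolve
        # back to an importable object, so serialisation fails.
        return False


named_ok = can_pickle(math.tanh)                 # importable as math.tanh
lambda_ok = can_pickle(lambda x: math.tanh(x))   # anonymous function
```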

But until we have a solution for this, why not go with something along the lines of

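import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np

def save_model_weights(model: eqx.Module, path: str):
    # Keep only the floating-point arrays (the weights); everything else is not saved.
    tree_weights, _ = eqx.partition(model, eqx.is_inexact_array)
    with open(path + "_weights.npy", "wb") as f:
        # Write each weight array one after another into the same file.
        for x in jax.tree_util.tree_leaves(tree_weights):
            np.save(f, x, allow_pickle=False)

def load_model_weights(model: eqx.Module, path: str) -> eqx.Module:
    # `model` acts as a template: it supplies the tree structure and the non-weight leaves.
    tree_weights, tree_other = eqx.partition(model, eqx.is_inexact_array)
    leaves_orig, treedef = jax.tree_util.tree_flatten(tree_weights)
    with open(path + "_weights.npy", "rb") as f:
        # Read the arrays back in the same order they were written.
        flat_state = [jnp.asarray(np.load(f)) for _ in leaves_orig]
    tree_weights = jax.tree_util.tree_unflatten(treedef, flat_state)
    return eqx.combine(tree_weights, tree_other)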

The naming should make it clear that it's just about the model weights. Once we have a good solution for pickling entire trees, we can add methods

def save_model(model: eqx.Module, path: str):
    ...

def load_model(path):
    ...
patrick-kidger commented 2 years ago

Right, pickling is a whole can of worms. Let's leave that alone for now.

Yep, you've convinced me! I think it seems reasonable to add something that saves just the leaves of a PyTree.

Your implementation is also a fair bit tidier than the one I suggested over on reddit; nice. Can you open a PR for this? I'll pre-emptively offer a review - mostly nits.

  1. Change the type annotation from Module to PyTree. Remember, Modules are never special-cased! Likewise, probably change the names to {save,load}_leaves or {de,}serialise_leaves or similar, to help emphasise this.
  2. Allow specifying the filter spec, instead of baking in is_inexact_array. And perhaps make the default is_array instead?
  3. Use jnp.{load,save} over np.{load,save}. I think this is needed to handle bfloat16 dtypes correctly.
  4. Allow path to be a Union[str, pathlib.Path]. I also probably wouldn't hardcode the + "_weights.npy" part of it; I think just pass path through unchanged. (Or only add a .npy suffix.)
  5. Validate that flat_state and leaves_orig match shape+dtype, and raise an error if not.
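A minimal sketch of review points 1, 2, 4 and 5, using only JAX and NumPy so that it works on any PyTree rather than on Modules specifically. The names `save_leaves`/`load_leaves` follow the suggested naming but are hypothetical, and `is_array` here is a local stand-in filter, not `eqx.is_array`. As an assumption for simplicity it uses `np.save`/`np.load` rather than `jnp.{load,save}` from point 3, so bfloat16 handling is not covered.

```python
import pathlib
import tempfile
from typing import Any, Callable, Union

import jax
import jax.numpy as jnp
import numpy as np


def is_array(x: Any) -> bool:
    # Stand-in default filter (point 2): JAX or NumPy arrays only.
    return isinstance(x, (jnp.ndarray, np.ndarray))


def save_leaves(path: Union[str, pathlib.Path], pytree: Any,
                filter_spec: Callable[[Any], bool] = is_array) -> None:
    # Point 4: accept str or Path and use it unchanged.
    with open(pathlib.Path(path), "wb") as f:
        for leaf in jax.tree_util.tree_leaves(pytree):
            if filter_spec(leaf):
                np.save(f, np.asarray(leaf), allow_pickle=False)


def load_leaves(path: Union[str, pathlib.Path], like: Any,
                filter_spec: Callable[[Any], bool] = is_array) -> Any:
    leaves, treedef = jax.tree_util.tree_flatten(like)
    new_leaves = []
    with open(pathlib.Path(path), "rb") as f:
        for old in leaves:
            if not filter_spec(old):
                new_leaves.append(old)  # unsaved leaves are taken from `like`
                continue
            loaded = np.load(f, allow_pickle=False)
            # Point 5: refuse mismatched shapes or dtypes.
            if loaded.shape != old.shape or loaded.dtype != old.dtype:
                raise ValueError(
                    f"expected {old.shape}/{old.dtype}, got {loaded.shape}/{loaded.dtype}"
                )
            # Keep NumPy leaves as NumPy; other arrays become JAX arrays.
            new_leaves.append(loaded if isinstance(old, np.ndarray) else jnp.asarray(loaded))
    return jax.tree_util.tree_unflatten(treedef, new_leaves)


params = {"w": jnp.ones((2, 3)), "b": np.zeros(3), "activation": np.tanh}
with tempfile.TemporaryDirectory() as tmp:
    path = pathlib.Path(tmp) / "params.npy"
    save_leaves(path, params)
    restored = load_leaves(path, {"w": jnp.zeros((2, 3)), "b": np.ones(3), "activation": np.tanh})
    try:
        load_leaves(path, {"w": jnp.zeros((3, 3)), "b": np.ones(3), "activation": np.tanh})
        mismatch_detected = False
    except ValueError:
        mismatch_detected = True
```

The `like` argument plays the role of `model` in the earlier draft: it supplies the structure and any leaves the filter chose not to save.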
patrick-kidger commented 2 years ago

Thought. We may also wish to allow customising the use of jnp.{load,save} at some point.

For example: whilst it hasn't happened yet, there's discussion about eventually being able to de/serialise JIT-compiled functions in JAX. Meanwhile, the next release of Equinox will have filter_jit return a new _JitWrapper object: instead of capturing its inputs (function, filter spec, etc.) via closure, it will capture them as leaves in a PyTree instead. Mostly that doesn't change anything, but in this context it offers a neat opportunity. We could de/serialise a filter-jit-wrapped function in the same way as everything else; we'd just need to use whatever mechanism is introduced for de/serialising JIT-compiled functions in place of jnp.{load,save}.

jaschau commented 2 years ago

Hi, thanks for the review and the thoughtful remarks! I agree with most of your comments, but I'm not entirely sold on exposing a function serialize_leaves, for two reasons. (a) The naming seems to imply that it can serialize the leaves of an arbitrary tree. In practice, you would be limited to serializing numeric leaves due to the use of jnp.save, so as a user you somehow have to be aware of what should probably be an implementation detail. (b) The predominant use case will probably be saving the weights of a model. At least, that's what I would look for as a user of a neural network library, and the connection between a method serialize_leaves and serializing weights might not be immediately clear to a new user. So I would rather make _serialize_leaves a library-internal method and expose serialize_model_weights to the user (which internally would call _serialize_leaves). What do you think?
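Point (a) is easy to check. With `allow_pickle=False`, as in the draft code earlier in the thread, `np.save` only accepts leaves that NumPy can turn into a fixed-dtype array; functions and `None` become object arrays and are rejected. A small sketch with a hypothetical helper `can_save_leaf`, writing to an in-memory buffer:

```python
import io

import jax.numpy as jnp
import numpy as np


def can_save_leaf(leaf) -> bool:
    """Return True if `leaf` can be written as a .npy record without pickling."""
    buf = io.BytesIO()
    try:
        # np.asarray turns non-numeric leaves (functions, None, ...) into object arrays.
        np.save(buf, np.asarray(leaf), allow_pickle=False)
        return True
    except ValueError:
        # NumPy refuses to write object arrays when allow_pickle=False.
        return False


def activation(x):
    return jnp.tanh(x)


results = {
    "jax array": can_save_leaf(jnp.ones(3)),
    "python float": can_save_leaf(1.5),
    "function": can_save_leaf(activation),
    "None": can_save_leaf(None),
}
```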

patrick-kidger commented 2 years ago

Hmm. I agree that the points you're making are reasonable.

If we were to use a tree-based name:

As you can probably tell, I'm disinclined to write something as specific as save_model_weights. It runs counter to the simple models-are-PyTrees idea that is what makes reasoning about Equinox so easy in the first place. I don't think de/serialisation is an important and special enough problem that it's worth breaking that abstraction.

My feeling is that issues like new-user-discoverability can be handled through appropriate documentation etc. Perhaps we could change the proposed heading above from "Serialisation" to "Serialisation (save/load models to disk)" if we wanted to really make it clear where to look.

WDYT?

jaschau commented 2 years ago

Hi, sorry for the late feedback.

As you can probably tell, I'm disinclined to write something as specific as save_model_weights. It runs counter to the simple models-are-PyTrees idea that is what makes reasoning about Equinox so easy in the first place. I don't think de/serialisation is an important and special enough problem that it's worth breaking that abstraction.

I can understand your reasoning here. With respect to (b), I still think that the best documentation is the documentation you do not need, so one alternative I could think of is to introduce save_weights and load_weights methods on eqx.Module that are one-line calls to tree_serialise_leaves. This would of course increase the amount of code in eqx.Module, but I think the one-liner would very naturally embody the models-are-PyTrees idea.

Would you be okay with that?

Irrespective of how we do it, I'd be happy to contribute a first draft of a PR sometime this week.

patrick-kidger commented 2 years ago

Is your reasoning that this would improve discoverability because it'll be in the documentation for Module? (Or because it'll appear in dir(Module())?)

At least wrt the former, this could be handled by improving the docstring, and I'd be happy to have that in. The problem with adding new methods is that this breaks a different principle in Eqx, which is that no Module method is special-cased (other than magic methods and tree_{un,}flatten).

(I appreciate it probably feels like you're threading a needle here.)

jaschau commented 2 years ago

Is your reasoning that this would improve discoverability because it'll be in the documentation for Module? (Or because it'll appear in dir(Module())?)

Yes, exactly. It seems like an obvious place to look as a user.

At least wrt the former, this could be handled by improving the docstring, and I'd be happy to have that in. The problem with adding new methods is that this breaks a different principle in Eqx, which is that no Module method is special-cased (other than magic methods and tree_{un,}flatten).

Could you elaborate on what you mean that no Module method is special cased?


(I appreciate it probably feels like you're threading a needle here.)

No worries at all, I think it is good if library authors are opinionated about their design choices! I appreciate the discussion. Edit: not sure if opinionated carries a negative connotation - if so, it wasn't meant in that way!

patrick-kidger commented 2 years ago

Yes, exactly. It seems like an obvious place to look as a user.

FWIW I do feel that discoverability could be better in Equinox. For example, tree_at is pretty important for modifying PyTrees, but it's not terribly well-advertised.

Could you elaborate on what you mean that no Module method is special cased?

Consider magic methods: these are special-cased by Python itself. MyClass.__len__ has meaning beyond simply being a method on a class, and changing magic methods allows you to change how Python handles your class. Likewise there are tree_{un,}flatten, which allow you to change how JAX handles your class.

(Another example: PyTorch special-cases forward, as the appropriate extension point for subclasses of torch.nn.Module.)

Equinox doesn't special-case any methods like this. You can't change how Equinox treats your class. (You don't need to.)
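To make the distinction concrete, here is a small JAX-only sketch (a hypothetical `Pair` class, not Equinox code) showing the two kinds of special-casing described above: Python gives `__len__` meaning, JAX gives `tree_flatten`/`tree_unflatten` meaning via `jax.tree_util.register_pytree_node_class`, and an ordinary method like `total` means nothing special to either.

```python
import jax


@jax.tree_util.register_pytree_node_class
class Pair:
    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __len__(self):
        # Special-cased by Python: len(pair) calls this.
        return 2

    def tree_flatten(self):
        # Special-cased by JAX: returns (children, auxiliary data).
        return (self.a, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Special-cased by JAX: rebuilds a Pair from its children.
        return cls(*children)

    def total(self):
        # Ordinary method: neither Python nor JAX treats it specially.
        return self.a + self.b


p = Pair(1.0, 2.0)
q = jax.tree_util.tree_map(lambda x: x * 10, p)  # JAX walks the leaves via tree_flatten
```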

Edit: not sure if opinionated carries a negative connotation - if so, it wasn't meant in that way!

Not at all! Not how I read it.

patrick-kidger commented 2 years ago

Also: this should respect eqx.experimental.{StateIndex,get_state,set_state}. Ideally the default should also handle NumPy arrays.

The following is entirely untested, but I think that probably means doing something like:

import pathlib
from typing import Any, Union

import jax
import jax.numpy as jnp
import numpy as np
from equinox import experimental, is_array_like  # Equinox APIs as referenced in this thread

PyTree = Any  # annotation-only alias

def _default_serialise_filter_spec(f, x):
    # Called once per leaf with the open file; writes the leaf, or nothing.
    if isinstance(x, experimental.StateIndex):
        jnp.save(f, experimental.get_state(x))
    elif is_array_like(x):
        jnp.save(f, x)
    # Any other leaf (functions, strings, ...) is skipped.

def _default_deserialise_filter_spec(f, x):
    # Must read leaves back in exactly the order they were written.
    if isinstance(x, experimental.StateIndex):
        experimental.set_state(x, jnp.load(f))
        return x
    elif isinstance(x, np.ndarray):  # check NumPy before JAX so NumPy leaves stay NumPy
        return np.load(f)
    elif isinstance(x, jnp.ndarray):
        return jnp.load(f)
    elif isinstance(x, np.generic):  # NumPy scalars: unwrap the 0-d array to the same scalar type
        return np.load(f)[()]
    elif isinstance(x, (bool, float, complex, int)):
        return np.load(f).item()
    else:
        return x  # nothing was saved for this leaf; reuse the one from `like`

def _assert_same(new, old):
    if type(new) is not type(old):
        raise RuntimeError(f"Deserialised leaf has type {type(new)}, expected {type(old)}.")
    if isinstance(new, (np.ndarray, jnp.ndarray)) and (new.shape != old.shape or new.dtype != old.dtype):
        raise RuntimeError(
            f"Deserialised leaf has shape {new.shape} and dtype {new.dtype}, "
            f"expected shape {old.shape} and dtype {old.dtype}."
        )

def _is_index(x):
    return isinstance(x, experimental.StateIndex)

def tree_serialise_leaves(path: Union[str, pathlib.Path], pytree: PyTree, filter_spec=_default_serialise_filter_spec, is_leaf=_is_index):
    # filter_spec is a prefix tree of `pytree` whose leaves are callables (f, x) -> None.
    with open(pathlib.Path(path).with_suffix(".npy"), "wb") as f:
        def _serialise(spec, x):
            def __serialise(y):
                spec(f, y)
            jax.tree_util.tree_map(__serialise, x, is_leaf=is_leaf)
        jax.tree_util.tree_map(_serialise, filter_spec, pytree, is_leaf=is_leaf)

def tree_deserialise_leaves(path: Union[str, pathlib.Path], like: PyTree, filter_spec=_default_deserialise_filter_spec, is_leaf=_is_index):
    # filter_spec is a prefix tree of `like` whose leaves are callables (f, x) -> new_x.
    with open(pathlib.Path(path).with_suffix(".npy"), "rb") as f:
        def _deserialise(spec, x):
            def __deserialise(y):
                return spec(f, y)
            return jax.tree_util.tree_map(__deserialise, x, is_leaf=is_leaf)
        out = jax.tree_util.tree_map(_deserialise, filter_spec, like, is_leaf=is_leaf)
    jax.tree_util.tree_map(_assert_same, out, like, is_leaf=is_leaf)
    return out
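The least obvious part of this design is that `filter_spec` is a prefix of the PyTree: `jax.tree_util.tree_map(fn, filter_spec, pytree)` hands each spec a whole subtree, and every leaf is then written to one file in flattening order, which is why loading must walk the tree in the same order. A JAX/NumPy-only sketch of just that mechanism, with hypothetical helpers `save_array`, `skip` and `serialise_with_prefix`, and an in-memory buffer instead of a file:

```python
import io

import jax
import jax.numpy as jnp
import numpy as np


def save_array(f, x):
    np.save(f, np.asarray(x), allow_pickle=False)


def skip(f, x):
    pass  # this leaf is not written


def serialise_with_prefix(f, filter_spec, pytree):
    def per_subtree(spec, subtree):
        # `spec` applies to every leaf of `subtree`.
        jax.tree_util.tree_map(lambda leaf: spec(f, leaf), subtree)
    jax.tree_util.tree_map(per_subtree, filter_spec, pytree)


params = {"weights": jnp.arange(3.0), "buffers": jnp.zeros(2)}

# Per-subtree specs: only the weights are written.
buf = io.BytesIO()
serialise_with_prefix(buf, {"weights": save_array, "buffers": skip}, params)
buf.seek(0)
first = np.load(buf)
rest = buf.read()  # nothing else was written

# A single spec is a prefix of the whole tree; dict keys flatten in sorted order.
buf_all = io.BytesIO()
serialise_with_prefix(buf_all, save_array, params)
buf_all.seek(0)
all_records = [np.load(buf_all) for _ in range(2)]
```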
patrick-kidger commented 2 years ago

@jaschau heads-up: when putting together your PR, branch off the v050 branch.

patrick-kidger commented 2 years ago

@jaschau Any plans to pick this up as a PR? No worries if not -- I'll do it -- I just want to get serialisation+deserialisation into the next release of Equinox.

jaschau commented 2 years ago

Hi @patrick-kidger, sorry for the lack of progress here. I've been side-tracked with other topics, so I haven't got around to working on this beyond our initial discussion. I'm afraid I cannot promise any serious progress on this from my side in the upcoming weeks, so please move ahead if you're eager to work on this.