patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Pattern for lossless round trip of module serialization without storing side information required by __init__? #535

Open cottrell opened 1 year ago

cottrell commented 1 year ago

Wouldn't it make sense to add something that wraps all this to make deserialization work? Currently it's lossy: you basically need to either define an init_param_from_params function or store the init params separately somewhere.

I might be missing something included elsewhere.

What I mean is that eqx.nn.MLP(2, 2, 2, 2, ...) should not be needed to "deserialize". For example, in pure JAX you write everything based on params; I simply have jsonifiers of params and it all works. The "init_params" are not needed.

A special classmethod might make sense here. I'm not deep enough into the internals of Equinox yet, but the philosophy sounds like it should all be separable and merely be a nice way of organizing JAX params.

def from_params(...):
   ...

Something like this might make sense.

https://docs.kidger.site/equinox/api/serialisation/

import equinox as eqx
import jax.random as jr

model_original = eqx.nn.MLP(2, 2, 2, 2, key=jr.PRNGKey(0))
eqx.tree_serialise_leaves("some_filename.eqx", model_original)
model_loaded = eqx.tree_deserialise_leaves("some_filename.eqx", model_original)

# To partially load weights: in this case load everything except the final layer.
model_partial = eqx.tree_at(lambda mlp: mlp.layers[-1], model_loaded, model_original)

patrick-kidger commented 1 year ago

So the thing is that you still need to store your forward pass somewhere. That's just some Python function, and you can basically do arbitrary things inside it. Unfortunately, Python doesn't really offer good ways to serialise arbitrary Python functions.

More broadly, your model might include quite a lot of things that aren't parameters (various flags, hyperparameters, etc.) Again, Python doesn't really offer a clean way to serialise arbitrary things like this. (pickle comes close but is fraught with edge-cases.)

The end result is basically that you still need to have access to the source code for, say, eqx.nn.MLP, in order to deserialise it.

This is no different in just JAX -- you can serialise your parameters just fine, but again you still need to record the "everything else".

cottrell commented 1 year ago

I think you are misunderstanding. This is about the inputs to __init__ which, if they were to be interpreted as code, would make for an incredibly inconvenient paradigm in the midst of what is an incredibly convenient library.

The code that defines the classes of course needs to be in scope at load time, a bit like unpickling a class. And of course we would need some convention for identifying keys that correspond to classes, but walking the tree from top to bottom and checking the scope might make this possible.

The alternative approach is to create a custom wrapper of __init__ that puts the (unused) init params in the state and uses these on reconstitution merely to create a Module that will then have its values replaced. This is probably easier, but it feels like there should be a better approach. This would effectively be like "you need the code" but without literally having new code all over the place just for init params.

Anyway, I'll look around for how to get the full pytree out of a Module. It seems like most operations aren't by default recursing through all modules. I imagine you have a util for that somewhere?

patrick-kidger commented 1 year ago

You mean, you're asking to serialise the __init__-time arguments alongside the saved parameters? The idea being that you can then load without an existing model to use as a reference?

> Anyway, I'll look around for how to get the full pytree out of a Module

As for this -- the Module is already a pytree!

cottrell commented 1 year ago

Wait ... isn't this (in examples) the pattern I'm asking about? I somehow didn't see that. Ok this makes more sense.

https://docs.kidger.site/equinox/examples/serialisation/

I'll post some interesting hack in a sec for REALLY going a bit into the rabbit hole of not even assembling the pytree ahead of time. Likely won't generalize.

patrick-kidger commented 1 year ago

> Wait ... isn't this (in examples) the pattern I'm asking about? I somehow didn't see that. Ok this makes more sense.

Yup, exactly! ;)

> I'll post some interesting hack in a sec for REALLY going a bit into the rabbit hole of not even assembling the pytree ahead of time. Likely won't generalize.

Haha, that sounds exciting. I look forward to seeing it.

cottrell commented 1 year ago

This is probably a horror show, and I'm not super familiar with clever ways to do things in pytrees, but "it works" (maybe) as an example I think.

The only interesting idea is to "ban" overriding __init__ and use an init classmethod instead, so that you can actually reconstitute the modules directly from their fields.

I'm just doing some total hack to leave a trace/clue as to the module name in a dict key.

import equinox as eqx
import jax
import jax.numpy as jnp

def recurse_get_state(x):
    if isinstance(x, eqx.Module):
        # NOTE: fragile naming convention
        return {f'module:{type(x).__name__}': recurse_get_state(x.__getstate__())}
    elif isinstance(x, dict):
        return {k: recurse_get_state(v) for k, v in x.items()}
    elif isinstance(x, list):
        return [recurse_get_state(v) for v in x]
    elif isinstance(x, tuple):
        return tuple(recurse_get_state(v) for v in x)
    else:
        return x

def reconstitute_from_root(scope, params):
    out = None
    if isinstance(params, dict):
        if len(params) == 1:
            k, v = list(params.items())[0]
            if isinstance(k, str) and k.startswith('module:'):
                assert len(params) == 1
                name = k.split(':')[1]
                class_ = scope[name]
                params_ = reconstitute_from_root(scope, v)
                out = class_(**params_)
            else:
                out = {k: reconstitute_from_root(scope, v) for k, v in params.items()}
        else:
            out = {k: reconstitute_from_root(scope, v) for k, v in params.items()}
    elif isinstance(params, list):
        out = [reconstitute_from_root(scope, v) for v in params]
    elif isinstance(params, tuple):
        out = tuple(reconstitute_from_root(scope, v) for v in params)
    else:
        out = params
    return out

def reconstitute(scope, params):
    # Thin wrapper around the recursive worker.
    return reconstitute_from_root(scope, params)

def check_identical(tree1, tree2):
    def compare_elements(x, y):
        return jnp.all(x == y)

    comparison_tree = jax.tree_util.tree_map(compare_elements, tree1, tree2)

    return all(jax.tree_util.tree_flatten(comparison_tree)[0])

class Linear(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    @classmethod
    def init(cls, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        weight = jax.random.normal(wkey, (out_size, in_size))
        bias = jax.random.normal(bkey, (out_size,))
        return cls(weight=weight, bias=bias)

class Another(eqx.Module):
    layers: list

    @classmethod
    def init(cls, n, in_size, out_size, key):
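        # Give each layer its own key so the layers are not initialised identically.
        keys = jax.random.split(key, n)
        layers = [Linear.init(in_size, out_size, k) for k in keys]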
        return cls(layers=layers)

def example():
    key = jax.random.PRNGKey(0)
    in_size = 12
    out_size = 3
    n = 5
    model = Another.init(n, in_size, out_size, key)
    params = recurse_get_state(model)
    model_ = reconstitute(globals(), params)
    print(f'check_identical={check_identical(model, model_)}')
    return model, model_

patrick-kidger commented 1 year ago

Haha, this actually isn't too bad. So the need for an init method can be avoided by using the same trick that Equinox uses during tree-unflattening:

import dataclasses

module = object.__new__(YourClass)
fieldnames = {f.name for f in dataclasses.fields(YourClass)}
assert set(params.keys()) == fieldnames
for key, value in params.items():
    object.__setattr__(module, key, value)

The "module:" fragility can be handled by also nesting dictionaries with one more layer of dictionaries. (So you serialise a dictionary as {"dict": the_actual_dict}, and a module as {"module": ...}).

You can probably tackle deserialisation by storing both module.__module__ and module.__qualname__ to do lookup. You can probably also serialise functions in the same way.

This still won't be able to handle arbitrary types, of course. Nor can it handle any functions or modules created in local scopes. (But those are quite unusual.)
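
As a rough illustration of the function case, the sketch below stores a function as its (__module__, __qualname__) pair and looks it up again with importlib. The helper names function_to_ref and ref_to_function and the {"function": ...} tag are hypothetical, chosen to mirror the tagging scheme described above; only the standard library is used.

```python
import importlib
import json


def function_to_ref(fn):
    """Describe an importable, module-level function by where it lives."""
    # Functions defined in local scopes (and lambdas) cannot be looked up by name.
    if "<locals>" in fn.__qualname__ or fn.__name__ == "<lambda>":
        raise ValueError(f"cannot reference non-importable function {fn!r}")
    return {"function": [fn.__module__, fn.__qualname__]}


def ref_to_function(ref):
    """Import the module, then walk the dotted qualname to find the object."""
    module_name, qualname = ref["function"]
    obj = importlib.import_module(module_name)
    for attr in qualname.split("."):
        obj = getattr(obj, attr)
    return obj


# Round trip through JSON text, using json.dumps itself as the example function.
ref = function_to_ref(json.dumps)
restored = ref_to_function(json.loads(json.dumps(ref)))
```

The reference is only a name, so loading succeeds only if the same function is importable under the same path at load time.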

Even so, with a bit of work I could see this maybe being robust enough for general use? We'd probably need to test it quite carefully...


One thing to note is that this isn't using the PyTree abstraction at all -- it's just recursing manually. I think this is probably the correct thing to do, actually. In general PyTrees don't try to offer anything like (de)serialisation rules.

cottrell commented 1 year ago

Thanks, ok I'll probably kick it along and try to make it more sane with your suggestions. The dict/module to special dict is nice.

I guess, like you said, with the caveat that this would only be valid for "vanilla" modules with standard types. Someone could do custom stuff I suppose and fall back to pickle, but I typically stay away from that kind of thing for any kind of persistence.

cottrell commented 1 year ago

Incorporated all the suggestions, I think, except for function serialization.

Haven't changed the names or anything yet.

Do you have a list of test modules somewhere? I mean a set of equinox.Module instances that you test against.

import dataclasses
import importlib

import equinox as eqx
import jax
import jax.numpy as jnp

# NOTE: see https://github.com/patrick-kidger/equinox/issues/535

def recurse_get_state(x):
    # TODO: consider functions?
    if isinstance(x, eqx.Module):
        return {'module': {(x.__class__.__module__, x.__class__.__qualname__): recurse_get_state(x.__getstate__())}}
    elif isinstance(x, dict):
        return {'dict': {k: recurse_get_state(v) for k, v in x.items()}}
    elif isinstance(x, list):
        return [recurse_get_state(v) for v in x]
    elif isinstance(x, tuple):
        return tuple(recurse_get_state(v) for v in x)
    else:
        return x

def init_from_state_params(class_, params):
    module = object.__new__(class_)
    fieldnames = {f.name for f in dataclasses.fields(class_)}
    assert set(params.keys()) == fieldnames
    for key, value in params.items():
        object.__setattr__(module, key, value)
    return module

def get_object_from_module_and_qualname(module_name, qualname):
    module = importlib.import_module(module_name)
    obj = module
    for attr in qualname.split('.'):
        obj = getattr(obj, attr)
    return obj

def reconstitute_from_root(params):
    out = None
    if isinstance(params, dict):
        assert len(params) == 1
        k, v = list(params.items())[0]
        if k == 'module':
            module, qualname = list(v.keys())[0]
            class_ = get_object_from_module_and_qualname(module, qualname)
            params_ = reconstitute_from_root(list(v.values())[0])
            out = init_from_state_params(class_, params_)
        elif k == 'dict':
            out = {k_: reconstitute_from_root(v_) for k_, v_ in v.items()}
        else:
            raise Exception(f'unknown key {k}')
    elif isinstance(params, list):
        out = [reconstitute_from_root(v) for v in params]
    elif isinstance(params, tuple):
        out = tuple(reconstitute_from_root(v) for v in params)
    else:
        out = params
    return out

def reconstitute(params):
    # Thin wrapper around the recursive worker.
    return reconstitute_from_root(params)

def check_identical(tree1, tree2):
    def compare_elements(x, y):
        return jnp.all(x == y)

    comparison_tree = jax.tree_util.tree_map(compare_elements, tree1, tree2)

    return all(jax.tree_util.tree_flatten(comparison_tree)[0])

class Linear(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

class Another(eqx.Module):
    layers: list

    def __init__(self, n, in_size, out_size, key):
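        # Give each layer its own key so the layers are not initialised identically.
        keys = jax.random.split(key, n)
        self.layers = [Linear(in_size, out_size, k) for k in keys]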

def example():
    key = jax.random.PRNGKey(0)
    in_size = 12
    out_size = 3
    n = 5
    model = Another(n, in_size, out_size, key)
    params = recurse_get_state(model)
    model_ = reconstitute(params)
    print(f'check_identical={check_identical(model, model_)}')
    return model, model_

patrick-kidger commented 1 year ago

Not for this problem, no!

But off the top of my head, I'd probably suggest testing some of the linear solvers from Lineax, the diffeq solvers from Diffrax, and some of the more unusual corners of Equinox, like stateful layers and shared layers.

cottrell commented 1 year ago

@patrick-kidger

Have added a bunch of tests and some hacky serialization thing because I need that. But this is probably orthogonal to this stuff.

There are a number of ugly workarounds in there. But for basic stuff it was quite easy.

If you know how to easily get your serializer to handle these things, that would be neat. I had a look around but it doesn't seem like JAX is very helpful ... can one even serialize a PyTreeDef? It seems like that should be the target.

I've put it in a gist as that might be easier but can paste it here if that is better.

https://gist.github.com/cottrell/f3d78b27a9dcd9d47dd7fd74f1841ab1

cottrell commented 1 year ago

Incidentally, if something doesn't already exist, the recursive get state or similar should be some kind of repr ... it seems a bit hard to see what state the modules are in.

patrick-kidger commented 1 year ago

Nice! Indeed one can't serialise a PyTreeDef so far as I know. As such I think the correct approach here is to ignore the PyTree structure and do de/serialisation without using that.

I can see this is starting to get quite involved. In particular the multiple different serde options (cloudpickle etc.) are probably going to start getting to be a bit too much. Moreover, an additional complexity arises with handling sharding and device placement of JAX arrays -- we should probably handle that correctly too.
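
As a small illustration of the device-placement concern, one simple (and lossy) approach is to pull every array to host memory before writing and to choose a device explicitly when loading. The sketch below assumes a plain dict of parameters and deliberately ignores sharding: it does not record or restore how an array was laid out across devices.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Assumed toy parameters; in practice these could live on any device.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}

# device_get copies every leaf to host memory as a NumPy array,
# discarding any information about where the array was placed.
host_params = jax.device_get(params)

# On load, placement has to be chosen again; here, the first available device.
restored = jax.device_put(host_params, jax.devices()[0])
```

A serializer that wanted to round-trip placement faithfully would need to store that information separately, which is part of the extra complexity mentioned above.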

As such I'm thinking this is starting to grow a bit beyond what really makes sense to include in Equinox. As it happens there's another library, Orbax, which is specifically dedicated to checkpointing. I'm not sure exactly what kind of PyTree serialisation they have, but a cleaned-up version of what you've got may be a good fit over there?

cottrell commented 1 year ago

Will check out Orbax gradually and report back if it makes sense. An opinionated take on serialization is desperately needed in the JAX ecosystem.