cottrell opened this issue 1 year ago
So the thing is that you still need to store your forward pass somewhere. That's just some Python function, and you can basically do arbitrary things inside it. Unfortunately, Python doesn't really offer good ways to serialise arbitrary Python functions.
More broadly, your model might include quite a lot of things that aren't parameters (various flags, hyperparameters, etc.). Again, Python doesn't really offer a clean way to serialise arbitrary things like this. (pickle comes close but is fraught with edge-cases.)
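Here is a small stdlib-only illustration of the point about functions. pickle does not store a function's code. It stores only a reference to where the function lives (module plus qualified name) and looks that reference up again on load. As a result, module-level functions round-trip, while local functions and lambdas cannot be found again. `roundtrips` and `make_local` are hypothetical helpers made up for this sketch. The exact exception raised varies across Python versions, so the sketch catches both types.

```python
import math
import pickle


def roundtrips(obj):
    """Return True if obj survives a pickle round trip as the very same object."""
    try:
        return pickle.loads(pickle.dumps(obj)) is obj
    except (pickle.PicklingError, AttributeError):
        # pickle could not find the object again by (module, qualname).
        return False


def make_local():
    # local_fn's qualname is "make_local.<locals>.local_fn": not importable.
    def local_fn(x):
        return 2 * x

    return local_fn


print(roundtrips(math.sqrt))        # stored as a reference to math.sqrt
print(roundtrips(make_local()))     # defined in a local scope
print(roundtrips(lambda x: x + 1))  # anonymous, so no name to look up
```

Libraries such as cloudpickle get around this by serialising code by value. The trade-off is that the resulting files are only meant to be read back by the same Python version.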
The end result is basically that you still need access to the source code for, say, eqx.nn.MLP in order to deserialise it.
This is no different in just JAX -- you can serialise your parameters just fine, but again you still need to record the "everything else".
I think you are misunderstanding. This is about the inputs to __init__, which, if they had to be interpreted as code, would make for an incredibly inconvenient paradigm in the midst of what is an incredibly convenient library.
The code which defines the classes of course needs to be in scope at load time ... a bit like how unpickling a class works. And of course we would need some convention for identifying keys that correspond to classes, but walking the tree from top to bottom and checking the scope might make this feasible.
The alternative approach is to create a custom wrapper around __init__ that puts the (otherwise unused) init params in the state, and uses these on reconstitution merely to create a Module that will then have its values replaced. This is probably easier, but it feels like there should be a better approach. It would effectively still be "you need the code", but without literally having new code all over the place just for init params.
Anyway, I'll look around for how to get the full pytree out of a Module. It seems like most operations don't recurse through all submodules by default. I imagine you have a util for that somewhere?
You mean, you're asking to serialise the __init__-time arguments alongside the saved parameters? The idea being that you can then load without needing an existing model to use as a reference?
Anyway, I'll look around for how to get the full pytree out of a Module
As for this -- the Module is already a pytree!
Wait ... isn't this (in examples) the pattern I'm asking about? I somehow didn't see that. Ok this makes more sense.
https://docs.kidger.site/equinox/examples/serialisation/
I'll post some interesting hack in a sec for REALLY going a bit into the rabbit hole of not even assembling the pytree ahead of time. Likely won't generalize.
Wait ... isn't this (in examples) the pattern I'm asking about? I somehow didn't see that. Ok this makes more sense.
Yup, exactly! ;)
I'll post some interesting hack in a sec for REALLY going a bit into the rabbit hole of not even assembling the pytree ahead of time. Likely won't generalize.
Haha, that sounds exciting. I look forward to seeing it.
This is probably a horror show, and I'm not super familiar with clever ways to do things in pytrees, but "it works" (maybe) as an example I think.
The only interesting idea is to "ban" overriding __init__ and use an init classmethod instead, so that you can actually reconstitute the modules directly from their saved fields.
I'm just doing some total hack to leave a trace/clue as to the module name in a dict key.
import dataclasses

import equinox as eqx
import jax
import jax.numpy as jnp


def recurse_get_state(x):
    if isinstance(x, eqx.Module):
        # Modules are dataclasses, so read their fields directly
        # (portable across Python versions, unlike __getstate__).
        fields = {f.name: getattr(x, f.name) for f in dataclasses.fields(x)}
        # NOTE: fragile naming convention
        return {f'module:{type(x).__name__}': recurse_get_state(fields)}
    elif isinstance(x, dict):
        return {k: recurse_get_state(v) for k, v in x.items()}
    elif isinstance(x, list):
        return [recurse_get_state(v) for v in x]
    elif isinstance(x, tuple):
        return tuple(recurse_get_state(v) for v in x)
    else:
        return x


def reconstitute_from_root(scope, params):
    if isinstance(params, dict):
        if len(params) == 1:
            k, v = next(iter(params.items()))
            if isinstance(k, str) and k.startswith('module:'):
                # Look the class up by name in `scope`, then call its
                # dataclass __init__ with the reconstituted fields.
                name = k.split(':', 1)[1]
                class_ = scope[name]
                params_ = reconstitute_from_root(scope, v)
                return class_(**params_)
        return {k: reconstitute_from_root(scope, v) for k, v in params.items()}
    elif isinstance(params, list):
        return [reconstitute_from_root(scope, v) for v in params]
    elif isinstance(params, tuple):
        return tuple(reconstitute_from_root(scope, v) for v in params)
    else:
        return params


def reconstitute(scope, params):
    return reconstitute_from_root(scope, params)


def check_identical(tree1, tree2):
    def compare_elements(x, y):
        return jnp.all(x == y)

    comparison_tree = jax.tree_util.tree_map(compare_elements, tree1, tree2)
    return all(jax.tree_util.tree_leaves(comparison_tree))


class Linear(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    @classmethod
    def init(cls, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        weight = jax.random.normal(wkey, (out_size, in_size))
        bias = jax.random.normal(bkey, (out_size,))
        return cls(weight=weight, bias=bias)


class Another(eqx.Module):
    layers: list

    @classmethod
    def init(cls, n, in_size, out_size, key):
        layers = [Linear.init(in_size, out_size, key) for _ in range(n)]
        return cls(layers=layers)


def example():
    key = jax.random.PRNGKey(0)
    in_size = 12
    out_size = 3
    n = 5
    model = Another.init(n, in_size, out_size, key)
    params = recurse_get_state(model)
    model_ = reconstitute(globals(), params)
    print(f'check_identical={check_identical(model, model_)}')
    return model, model_
Haha, this actually isn't too bad. So the need for an init method can be avoided by using the same trick that Equinox uses during tree-unflattening:
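import dataclasses

# Create the instance without running __init__, so no init arguments are needed.
module = object.__new__(YourClass)
fieldnames = {f.name for f in dataclasses.fields(YourClass)}
assert set(params.keys()) == fieldnames
for key, value in params.items():
    # Modules are frozen dataclasses, so normal assignment would raise;
    # object.__setattr__ bypasses that.
    object.__setattr__(module, key, value)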
The "module:" fragility can be handled by wrapping each dictionary in one more layer of dictionary. (So you serialise a dictionary as {"dict": the_actual_dict}, and a module as {"module": ...}.)
You can probably tackle deserialisation by storing both type(module).__module__ and type(module).__qualname__ to do the lookup. You can probably also serialise functions in the same way.
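This combines the two suggestions into one fully tagged, JSON-safe encoding. The sketch rests on several assumptions:
- Plain frozen dataclasses stand in for eqx.Module, so it runs without Equinox.
- Leaves are JSON scalars. Real arrays would need to go elsewhere, for example into a separate binary file.
- Dict keys are strings.
- `encode`, `decode`, `_lookup` and `_ref` are hypothetical names.

Classes and functions are both stored as (module, qualname) pairs. Anything defined in a local scope is rejected.

```python
import dataclasses
import importlib
import json


def _lookup(module_name, qualname):
    # Import the module, then walk the dotted qualname (handles nested classes).
    obj = importlib.import_module(module_name)
    for attr in qualname.split("."):
        obj = getattr(obj, attr)
    return obj


def _ref(obj):
    # Record where a class or function lives, never its code.
    if "<locals>" in obj.__qualname__:
        raise ValueError(f"cannot serialise local object {obj.__qualname__}")
    return [obj.__module__, obj.__qualname__]


def encode(x):
    # Every node is a one-key dict naming its kind, so no key names are special.
    if dataclasses.is_dataclass(x) and not isinstance(x, type):
        fields = {f.name: encode(getattr(x, f.name)) for f in dataclasses.fields(x)}
        return {"module": {"ref": _ref(type(x)), "fields": fields}}
    if callable(x) and hasattr(x, "__qualname__"):
        return {"function": _ref(x)}
    if isinstance(x, dict):
        return {"dict": {k: encode(v) for k, v in x.items()}}
    if isinstance(x, list):
        return {"list": [encode(v) for v in x]}
    if isinstance(x, tuple):
        # JSON has no tuples, so the tag is what preserves the distinction.
        return {"tuple": [encode(v) for v in x]}
    if x is None or isinstance(x, (bool, int, float, str)):
        return {"leaf": x}
    raise TypeError(f"unsupported type {type(x).__name__}")


def decode(node):
    ((kind, payload),) = node.items()
    if kind == "module":
        cls = _lookup(*payload["ref"])
        obj = object.__new__(cls)  # bypass __init__, as Equinox does when unflattening
        for name, value in payload["fields"].items():
            object.__setattr__(obj, name, decode(value))  # works on frozen dataclasses
        return obj
    if kind == "function":
        return _lookup(*payload)
    if kind == "dict":
        return {k: decode(v) for k, v in payload.items()}
    if kind == "list":
        return [decode(v) for v in payload]
    if kind == "tuple":
        return tuple(decode(v) for v in payload)
    if kind == "leaf":
        return payload
    raise ValueError(f"unknown tag {kind}")


# Stand-ins for eqx.Module (Equinox modules are frozen dataclasses too).
@dataclasses.dataclass(frozen=True)
class Layer:
    weight: list
    activation: object


@dataclasses.dataclass(frozen=True)
class Model:
    layers: list
    config: dict
```

Saving is then `json.dumps(encode(model))` and loading is `decode(json.loads(text))`. As noted above, the classes and functions must still be importable at load time.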
This still won't be able to handle arbitrary types, of course. Nor can it handle any functions or modules created in local scopes. (But those are quite unusual.)
Even so, with a bit of work I could see this maybe being robust enough for general use? We'd probably need to test it quite carefully...
One thing to note is that this isn't using the PyTree abstraction at all -- it's just recursing manually. I think this is probably the correct thing to do, actually. In general PyTrees don't try to offer anything like (de)serialisation rules.
Thanks, OK, I'll probably kick it along and try to make it saner with your suggestions. The trick of wrapping dicts/modules in a tagged dict is nice.
I guess, like you said, with the caveat that this would only be valid for "vanilla" modules with standard types. Someone could do custom stuff, I suppose, and fall back to pickle, but I typically stay away from that kind of thing for any kind of persistence.
Incorporated all the suggestions, I think, except for function serialisation.
Haven't changed the names or anything yet.
Do you have a list of test modules somewhere? I mean a set of equinox.Module instances that you test against.
import dataclasses
import importlib

import equinox as eqx
import jax
import jax.numpy as jnp

# NOTE: see https://github.com/patrick-kidger/equinox/issues/535


def recurse_get_state(x):
    # TODO: consider functions?
    if isinstance(x, eqx.Module):
        cls = type(x)
        # Read the dataclass fields directly, matching the check in init_from_state_params.
        fields = {f.name: getattr(x, f.name) for f in dataclasses.fields(x)}
        return {'module': {(cls.__module__, cls.__qualname__): recurse_get_state(fields)}}
    elif isinstance(x, dict):
        return {'dict': {k: recurse_get_state(v) for k, v in x.items()}}
    elif isinstance(x, list):
        return [recurse_get_state(v) for v in x]
    elif isinstance(x, tuple):
        return tuple(recurse_get_state(v) for v in x)
    else:
        return x


def init_from_state_params(class_, params):
    # Bypass __init__, as Equinox does when unflattening.
    module = object.__new__(class_)
    fieldnames = {f.name for f in dataclasses.fields(class_)}
    assert set(params.keys()) == fieldnames
    for key, value in params.items():
        object.__setattr__(module, key, value)
    return module


def get_object_from_module_and_qualname(module_name, qualname):
    module = importlib.import_module(module_name)
    obj = module
    for attr in qualname.split('.'):
        obj = getattr(obj, attr)
    return obj


def reconstitute_from_root(params):
    if isinstance(params, dict):
        assert len(params) == 1
        k, v = next(iter(params.items()))
        if k == 'module':
            (module, qualname), fields = next(iter(v.items()))
            class_ = get_object_from_module_and_qualname(module, qualname)
            return init_from_state_params(class_, reconstitute_from_root(fields))
        elif k == 'dict':
            return {k_: reconstitute_from_root(v_) for k_, v_ in v.items()}
        else:
            raise ValueError(f'unknown key {k}')
    elif isinstance(params, list):
        return [reconstitute_from_root(v) for v in params]
    elif isinstance(params, tuple):
        return tuple(reconstitute_from_root(v) for v in params)
    else:
        return params


def reconstitute(params):
    return reconstitute_from_root(params)


def check_identical(tree1, tree2):
    def compare_elements(x, y):
        return jnp.all(x == y)

    comparison_tree = jax.tree_util.tree_map(compare_elements, tree1, tree2)
    return all(jax.tree_util.tree_leaves(comparison_tree))


class Linear(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))


class Another(eqx.Module):
    layers: list

    def __init__(self, n, in_size, out_size, key):
        self.layers = [Linear(in_size, out_size, key) for _ in range(n)]


def example():
    key = jax.random.PRNGKey(0)
    in_size = 12
    out_size = 3
    n = 5
    model = Another(n, in_size, out_size, key)
    params = recurse_get_state(model)
    model_ = reconstitute(params)
    print(f'check_identical={check_identical(model, model_)}')
    return model, model_
Not for this problem, no!
But off the top of my head, I'd probably suggest testing some of the linear solvers from Lineax, the diffeq solvers from Diffrax, and some of the more unusual corners of Equinox, like stateful layers and shared layers.
@patrick-kidger
Have added a bunch of tests and some hacky serialization thing because I need that. But this is probably orthogonal to this stuff.
There are a number of ugly workarounds in there. But for basic stuff it was quite easy.
If you know how to easily get your serializer to handle these things, that would be neat. I had a look around, but it doesn't seem like JAX is very helpful ... can one even serialize a PyTreeDef? It seems like that should be the target.
I've put it in a gist as that might be easier but can paste it here if that is better.
https://gist.github.com/cottrell/f3d78b27a9dcd9d47dd7fd74f1841ab1
Incidentally, if something like this doesn't already exist, the recursive get-state (or similar) should be exposed as some kind of repr ... it's a bit hard to see what state the modules are in.
Nice! Indeed, one can't serialise a PyTreeDef so far as I know. As such, I think the correct approach here is to ignore the PyTree structure and do de/serialisation without using it.
I can see this is starting to get quite involved. In particular, the multiple different serde options (cloudpickle etc.) are probably going to become a bit too much. Moreover, an additional complexity arises with handling sharding and device placement of JAX arrays -- we should probably handle that correctly too.
As such I'm thinking this is starting to grow a bit beyond what really makes sense to include in Equinox. As it happens there's another library, Orbax, which is specifically dedicated to checkpointing. I'm not sure exactly what kind of PyTree serialisation they have, but a cleaned-up version of what you've got may be a good fit over there?
Will check out orbax gradually and report back if it makes sense. An opinionated take on serialization is desperately needed in the jax ecosystem.
Wouldn't it make sense to add something that wraps all this to make deserialization work? Currently it's lossy: you basically need to either define an init_param_from_params function or store the init params separately somewhere.
I might be missing something included elsewhere.
And what I mean is that calling eqx.nn.MLP(2, 2, 2, 2, ...) should not be needed to "deserialize". For example, in pure JAX you write everything in terms of params, and I simply have jsonifiers of params and it all works. The "init_params" are not needed. A special classmethod might make sense here. I'm not deep enough into the internals of Equinox yet, but the philosophy sounds like everything should be separable, with Equinox merely being a nice way of organizing JAX params.
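For comparison, the pure-JAX workflow described here might look like the sketch below. Parameters are a plain nested dict of arrays, so a "jsonifier" just converts each array to nested lists and back. The forward function needs no init arguments because every shape travels with its array. `init_params`, `forward`, `params_to_json` and `params_from_json` are hypothetical names.

```python
import json

import jax
import jax.numpy as jnp


def init_params(key, sizes):
    # Hypothetical MLP parameters: {"layer0": {"w": ..., "b": ...}, ...}
    keys = jax.random.split(key, len(sizes) - 1)
    return {
        f"layer{i}": {
            "w": jax.random.normal(k, (n_out, n_in)),
            "b": jnp.zeros(n_out),
        }
        for i, (k, n_in, n_out) in enumerate(zip(keys, sizes[:-1], sizes[1:]))
    }


def forward(params, x):
    # Shapes are read off the arrays themselves; no hyperparameters needed.
    for i in range(len(params)):
        layer = params[f"layer{i}"]
        x = jnp.tanh(layer["w"] @ x + layer["b"])
    return x


def params_to_json(params):
    return json.dumps(jax.tree_util.tree_map(lambda a: a.tolist(), params))


def params_from_json(text):
    # Each nested list is one array's payload, so treat lists as leaves.
    return jax.tree_util.tree_map(
        jnp.asarray, json.loads(text), is_leaf=lambda x: isinstance(x, list)
    )
```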
https://docs.kidger.site/equinox/api/serialisation/