patrick-kidger closed this issue 2 years ago.
I would like to pick up on a comment you made in https://www.reddit.com/r/MachineLearning/comments/u34oh2/d_what_jax_nn_library_to_use/i4umg44/?context=3.
> The reason I'm dragging my heels on this is that I'm not yet completely sold on whether to use the code above, or to instead try and pickle the entire PyTree in one go. The former requires you to be able to produce the PyTree structure yourself (and by default doesn't save e.g. NumPy arrays); the latter has annoying compatibility concerns. (Also, I think Flax has a couple of ways of doing de/serialisation that might be worth using as inspiration.) So I want to be sure we get this right.
I agree that it would be attractive to have an option to serialize complete PyTrees, but I also see some challenges. Just to give an example, suppose you have an Equinox module such as:
```python
from typing import Any

import equinox as eqx
import jax.nn as jnn

class MyModule(eqx.Module):
    activation: Any

module = MyModule(activation=jnn.tanh)
```
Pickling this will fail, because `jnn.tanh` seems to point to a lambda and `pickle` doesn't support pickling lambdas (https://stackoverflow.com/questions/25348532/can-python-pickle-lambda-functions).
But until we have a solution for this, why not go with something along the lines of:
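The underlying limitation can be seen without Equinox at all: `pickle` stores functions by reference (module plus qualified name), so a named function such as `math.tanh` round-trips, but a lambda has no importable name. This is a minimal standard-library sketch; `try_pickle` is a hypothetical helper, and the `except` clause is deliberately broad because the exact exception type has varied across Python versions.

```python
import math
import pickle

def try_pickle(obj) -> bool:
    """Return True if `obj` can be pickled, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False

# Builtin functions are pickled as a reference to `math.tanh`.
print(try_pickle(math.tanh))    # True
# A lambda's qualified name is "<lambda>", which pickle cannot look up.
print(try_pickle(lambda x: x))  # False
```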
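```python
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np

def save_model_weights(model: eqx.Module, path: str):
    # Keep only the floating-point arrays (the weights); everything else is dropped.
    tree_weights, tree_other = eqx.partition(model, eqx.is_inexact_array)
    with open(path + "_weights.npy", "wb") as f:
        # Each leaf is written as a consecutive .npy record in a single file.
        for x in jax.tree_util.tree_leaves(tree_weights):
            np.save(f, x, allow_pickle=False)

def load_model_weights(model: eqx.Module, path: str) -> eqx.Module:
    # `model` only supplies the PyTree structure; its weights are replaced.
    tree_weights, tree_other = eqx.partition(model, eqx.is_inexact_array)
    leaves_orig, treedef = jax.tree_util.tree_flatten(tree_weights)
    with open(path + "_weights.npy", "rb") as f:
        # Records are read back in the same (flattening) order they were written.
        flat_state = [jnp.asarray(np.load(f)) for _ in leaves_orig]
    tree_weights = jax.tree_util.tree_unflatten(treedef, flat_state)
    return eqx.combine(tree_weights, tree_other)
```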
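The proposal relies on the fact that several `np.save` calls on one open file produce back-to-back `.npy` records, which repeated `np.load` calls on the same file object read back one at a time, while the treedef from flattening a template restores the structure. A minimal sketch of just that mechanism, assuming a plain nested dict stands in for the partitioned weights of a model and an in-memory `io.BytesIO` stands in for the file:

```python
import io

import jax
import numpy as np

# A plain dict stands in for the weights PyTree of a model.
weights = {"linear": {"w": np.ones((2, 3), np.float32), "b": np.zeros(3, np.float32)}}

buf = io.BytesIO()  # in-memory stand-in for open(path, "wb")
for leaf in jax.tree_util.tree_leaves(weights):
    # Each call appends one self-describing .npy record (header + data).
    np.save(buf, leaf, allow_pickle=False)

buf.seek(0)
leaves_orig, treedef = jax.tree_util.tree_flatten(weights)
# Each np.load consumes exactly one record, in the order they were written.
flat_state = [np.load(buf) for _ in leaves_orig]
restored = jax.tree_util.tree_unflatten(treedef, flat_state)

print(restored["linear"]["w"].shape)  # (2, 3)
```

Note that nothing in the file records the structure itself, which is why loading needs a template PyTree with the same layout.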
The naming should make it clear that it's just about the model weights. Once we have a good solution for pickling entire trees, we can add functions:
```python
def save_model(model: eqx.Module, path: str):
    ...

def load_model(path):
    ...
```
Right, pickling is a whole can of worms. Let's leave that alone for now.
Yep, you've convinced me! I think it seems reasonable to add something that saves just the leaves of a PyTree.
Your implementation is also a fair bit tidier than the one I suggested over on reddit; nice. Can you open a PR for this? I'll pre-emptively offer a review - mostly nits.
- Call them `{save,load}_leaves` or `{de,}serialise_leaves` or similar, to help emphasise this.
- Allow a filter other than the hardcoded `is_inexact_array`. And perhaps make the default `is_array` instead?
- Use `jnp.{load,save}` over `np.{load,save}`. I think this is needed to handle `bfloat16` dtypes correctly.
- Allow `path` to be a `Union[str, pathlib.Path]`. I also probably wouldn't hardcode the `+ "_weights.npy"` part of it; I think just pass `path` through unchanged. (Or only add a `.npy` suffix.)
- Check that `flat_state` and `leaves_orig` match shape+dtype, and raise an error if not.

Thought: we may also wish to allow customising the use of `jnp.{load,save}` at some point.
For example: whilst it hasn't happened yet, there's discussion about eventually being able to de/serialise JIT-compiled functions in JAX. Meanwhile, the next release of Equinox will have `filter_jit` return a new `_JitWrapper` object: instead of capturing its inputs (function, filter spec, etc.) via closure, it will capture them as leaves in a PyTree instead. Mostly that doesn't change anything, but in this context it offers a neat opportunity. We could de/serialise a filter-jit-wrapped function in the same way as everything else; we'd just need to use whatever mechanism is introduced for de/serialising JIT-compiled functions in place of `jnp.{load,save}`.
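Two of the review points, checking shape+dtype on load and leaving room to swap out the load function, can be sketched together. `load_leaves_checked` and its `load_fn` parameter are hypothetical names chosen for illustration, and the template is assumed to contain only array leaves (as after partitioning out the weights); `np.load` is the default so the sketch needs only NumPy and JAX's tree utilities.

```python
import io
from typing import Any, BinaryIO, Callable

import jax
import numpy as np

def load_leaves_checked(
    f: BinaryIO,
    like: Any,
    load_fn: Callable[[BinaryIO], np.ndarray] = np.load,  # swappable loader
) -> Any:
    """Read one record per leaf of `like`, checking the shape and dtype of each."""
    leaves_orig, treedef = jax.tree_util.tree_flatten(like)
    flat_state = []
    for i, old in enumerate(leaves_orig):
        new = load_fn(f)
        if new.shape != old.shape or new.dtype != old.dtype:
            raise RuntimeError(
                f"Leaf {i}: expected shape {old.shape} and dtype {old.dtype}, "
                f"got shape {new.shape} and dtype {new.dtype}."
            )
        flat_state.append(new)
    return jax.tree_util.tree_unflatten(treedef, flat_state)

# Saved a (2, 2) array, but the template expects shape (3,).
buf = io.BytesIO()
np.save(buf, np.zeros((2, 2), np.float32))
buf.seek(0)
try:
    load_leaves_checked(buf, [np.ones((3,), np.float32)])
except RuntimeError as e:
    print(e)  # Leaf 0: expected shape (3,) and dtype float32, got shape (2, 2) and dtype float32.
```

Passing the loader in as an argument is one way to leave the door open for replacing `jnp.{load,save}` later without changing the traversal logic.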
Hi, thanks for the review and the thoughtful remarks! I agree with most of your comments, but I'm not entirely sold on exposing a function `serialize_leaves`. There are two reasons:
a) The naming seems to imply that it can serialize the leaves of an arbitrary tree. In practice, you would be limited to serializing numeric leaves due to the use of `jnp.save`, so as a user you somehow have to be aware of what should probably be an implementation detail.
b) The predominant use case will probably be saving the weights of a model. At least, that's what I would look for as a user of a neural network library, and the connection between a function `serialize_leaves` and serializing weights might not be immediately clear to a new user.
So I would rather make `_serialize_leaves` a library-internal function and expose `serialize_model_weights` to the user (which internally would call `_serialize_leaves`).
What do you think?
Hmm. I agree that the points you're making are reasonable.
If we were to use a tree-based name:
- Have `filter_spec` specify the saving/loading function, so that it really is this more general thing. For example `filter_spec=lambda x: jnp.save if is_array(x) else None`. (Valid values being either callables or `None`.)
- … `apply_updates` and `tree_at`. Perhaps we could split out the above page into multiple pages:
```
Utilities
├ Manipulation
│ ├ apply_updates # hmm, this one doesn't fit the consistent naming scheme.
│ ├ tree_at
│ └ tree_inference # new in the next update, don't worry if you don't recognise this ;)
├ Serialisation
│ ├ tree_serialise_leaves
│ └ tree_deserialise_leaves
└ Miscellaneous
  ├ tree_pformat
  ├ tree_equal
  └ static_field # This should really go with Module but it's a pretty advanced
                 # thing with niche uses, so we hide it here instead.
Experimental
└ Stateful operations
  ├ StateIndex
  ├ get_state
  └ set_state
```
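The per-leaf `filter_spec` idea above (a callable that returns either a saving function or `None`) can be sketched as follows. `save_leaves` and `is_array` are hypothetical stand-ins, not Equinox functions, and `np.save` is used in place of `jnp.save` so the sketch only depends on NumPy for the file format.

```python
import io
from typing import Any, BinaryIO, Callable, Optional

import jax
import jax.numpy as jnp
import numpy as np

Saver = Callable[[BinaryIO, Any], None]

def is_array(x: Any) -> bool:
    # Hypothetical stand-in for eqx.is_array: NumPy or JAX arrays only.
    return isinstance(x, (np.ndarray, jax.Array))

def save_leaves(f: BinaryIO, pytree: Any, filter_spec: Callable[[Any], Optional[Saver]]) -> int:
    """Call `filter_spec` on each leaf; save it with the returned function, or skip it on None.

    Returns the number of leaves written."""
    n = 0
    for leaf in jax.tree_util.tree_leaves(pytree):
        saver = filter_spec(leaf)
        if saver is not None:
            saver(f, leaf)
            n += 1
    return n

# A mixed PyTree: arrays alongside a function and a Python int.
tree = {"w": jnp.ones(3), "act": np.tanh, "depth": 2, "b": np.zeros(2)}
buf = io.BytesIO()
n = save_leaves(buf, tree, filter_spec=lambda x: np.save if is_array(x) else None)
print(n)  # 2
```

Only `"b"` and `"w"` are written (in key-sorted flattening order); the function and the integer are skipped and would have to come from a template PyTree when loading.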
As you can probably tell, I'm disinclined to write something as specific as `save_model_weights`. It runs counter to the simple models-are-PyTrees idea that is what makes reasoning about Equinox so easy in the first place. I don't think de/serialisation is an important and special enough problem that it's worth breaking that abstraction.
My feeling is that issues like new-user-discoverability can be handled through appropriate documentation etc. Perhaps we could change the proposed heading above from "Serialisation" to "Serialisation (save/load models to disk)" if we wanted to really make it clear where to look.
WDYT?
Hi, sorry for the late feedback.
> As you can probably tell, I'm disinclined to write something as specific as `save_model_weights`. It runs counter to the simple models-are-PyTrees idea that is what makes reasoning about Equinox so easy in the first place. I don't think de/serialisation is an important and special enough problem that it's worth breaking that abstraction.
I can understand your reasoning here.
With respect to (b), I still think that the best documentation is the documentation that you do not need, so one alternative I could think of is introducing methods `save_weights` and `load_weights` on `eqx.Module` which are one-line calls to `tree_serialise_leaves`. This would of course increase the amount of code in `eqx.Module`, but I think that the one-liner would very naturally embody the models-are-PyTrees idea.
Would you be okay with that?
Irrespective of how we do it, I'd be happy to contribute a first draft of a PR sometime this week.
Is your reasoning that this would improve discoverability because it'll be in the documentation for `Module`? (Or because it'll appear in `dir(Module())`?)
At least wrt the former, this could be handled by improving the docstring, and I'd be happy to have that in. The problem with adding new methods is that this breaks a different principle in Eqx, which is that no `Module` method is special-cased. (Other than magic methods and `tree_{un,}flatten`.)
(I appreciate it probably feels like you're threading a needle here.)
> Is your reasoning that this would improve discoverability because it'll be in the documentation for `Module`? (Or because it'll appear in `dir(Module())`?)
Yes, exactly. It seems like an obvious place where to look as a user.
> At least wrt the former then this could be handled by improving the docstring and I'd be happy to have that in. The problem with adding new methods is that this breaks a different principle in Eqx, which is that no Module method is special cased. (Other than magic methods and `tree_{un,}flatten`.)
Could you elaborate on what you mean by no `Module` method being special-cased?
> (I appreciate it probably feels like you're threading a needle here.)
No worries at all; I think it is good if library authors are opinionated about their design choices! I appreciate the discussion. Edit: not sure if "opinionated" carries a negative connotation; if so, it wasn't meant in that way!
> Yes, exactly. It seems like an obvious place where to look as a user.

FWIW I do feel that discoverability could be better in Equinox. For example `tree_at` is pretty important for modifying PyTrees, but it's not terribly well-advertised.
> Could you elaborate on what you mean that no Module method is special cased?

Consider magic methods: these are special-cased by Python itself. `MyClass.__len__` has meaning beyond simply being a method on a class, and changing magic methods allows you to change how Python handles your class. Likewise there is `tree_{un,}flatten`, and these allow you to control how JAX handles your class.
(Another example: PyTorch special-cases `forward`, as the appropriate extension point for subclasses of `torch.nn.Module`.)
Equinox doesn't special-case any methods like this. You can't change how Equinox treats your class. (You don't need to.)
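For reference, this is what the `tree_{un,}flatten` extension point looks like on an ordinary class registered with `jax.tree_util.register_pytree_node_class`. `Pair` is a hypothetical class for illustration; subclasses of `eqx.Module` are registered as PyTrees automatically, so users don't write these methods themselves.

```python
import jax

@jax.tree_util.register_pytree_node_class
class Pair:
    def __init__(self, a, b):
        self.a = a
        self.b = b

    # JAX calls this to split the object into (children, static aux data)...
    def tree_flatten(self):
        return (self.a, self.b), None

    # ...and this to rebuild it from (aux data, children).
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

p = Pair(1.0, 2.0)
print(jax.tree_util.tree_leaves(p))  # [1.0, 2.0]
scaled = jax.tree_util.tree_map(lambda x: x * 10, p)
print(scaled.a)  # 10.0
```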
> Edit: not sure if opinionated carries a negative connotation - if so, it wasn't meant in that way!
Not at all! Not how I read it.
Also: this should respect `eqx.experimental.{StateIndex,get_state,set_state}`. Ideally the default should also handle NumPy arrays.
The following is entirely untested, but I think that probably means doing something like:
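```python
import pathlib
from typing import Any, BinaryIO, Union

import jax
import jax.numpy as jnp
import numpy as np
from equinox import experimental, is_array_like

PyTree = Any

# Savers are called as `saver(f, leaf)`.
def _save_index(f: BinaryIO, x: experimental.StateIndex) -> None:
    value = experimental.get_state(x)
    jnp.save(f, value)

def _save_no(f: BinaryIO, x: Any) -> None:
    pass  # not saved; the value is taken from `like` when loading

# Loaders are called as `loader(f, like_leaf)` and return the new leaf.
# Wrappers are needed: jnp.load/np.load would treat a second positional
# argument as `mmap_mode`.
def _load_jax_array(f: BinaryIO, x: Any):
    return jnp.load(f)

def _load_numpy_array(f: BinaryIO, x: Any):
    return np.load(f)

def _load_numpy_scalar(f: BinaryIO, x: Any):
    return np.load(f)[()]  # 0-d array -> NumPy scalar of the same dtype

def _load_index(f: BinaryIO, x: experimental.StateIndex):
    value = jnp.load(f)
    experimental.set_state(x, value)
    return x

def _load_scalar(f: BinaryIO, x: Any):
    return np.load(f).item()  # 0-d array -> Python scalar

def _load_no(f: BinaryIO, x: Any):
    return x

# Filter specs map each leaf to the function used to save/load it.
def _default_serialise_filter_spec(x):
    if is_array_like(x):
        return jnp.save
    elif isinstance(x, experimental.StateIndex):
        return _save_index
    else:
        return _save_no

def _default_deserialise_filter_spec(x):
    if isinstance(x, jnp.ndarray):
        return _load_jax_array
    elif isinstance(x, np.ndarray):
        return _load_numpy_array
    elif isinstance(x, np.generic):
        return _load_numpy_scalar
    elif isinstance(x, (bool, float, complex, int)):
        return _load_scalar
    elif isinstance(x, experimental.StateIndex):
        return _load_index
    else:
        return _load_no

def _assert_same(new, old):
    if type(new) is not type(old):
        raise RuntimeError(f"Expected a leaf of type {type(old)}, got {type(new)}.")
    if isinstance(new, (np.ndarray, jnp.ndarray)) and (new.shape != old.shape or new.dtype != old.dtype):
        raise RuntimeError(
            f"Expected shape {old.shape} and dtype {old.dtype}, got shape {new.shape} and dtype {new.dtype}."
        )

def _is_index(x):
    return isinstance(x, experimental.StateIndex)

def tree_serialise_leaves(path: Union[str, pathlib.Path], pytree: PyTree, filter_spec=_default_serialise_filter_spec, is_leaf=_is_index):
    with open(pathlib.Path(path).with_suffix(".npy"), "wb") as f:
        def _serialise(spec, x):
            def __serialise(y):
                spec(y)(f, y)  # look up this leaf's saver, then call it
                return y
            jax.tree_util.tree_map(__serialise, x, is_leaf=is_leaf)
        # `filter_spec` may be a single callable or a prefix PyTree of callables.
        jax.tree_util.tree_map(_serialise, filter_spec, pytree, is_leaf=is_leaf)

def tree_deserialise_leaves(path: Union[str, pathlib.Path], like: PyTree, filter_spec=_default_deserialise_filter_spec, is_leaf=_is_index):
    with open(pathlib.Path(path).with_suffix(".npy"), "rb") as f:
        def _deserialise(spec, x):
            def __deserialise(y):
                return spec(y)(f, y)  # look up this leaf's loader, then call it
            return jax.tree_util.tree_map(__deserialise, x, is_leaf=is_leaf)
        out = jax.tree_util.tree_map(_deserialise, filter_spec, like, is_leaf=is_leaf)
    jax.tree_util.tree_map(_assert_same, out, like, is_leaf=is_leaf)
    return out
```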
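Since `eqx.experimental` is not available in every Equinox version, here is a cut-down, self-contained sketch of the same save/load dispatch with the `StateIndex` handling omitted. `serialise` and `deserialise` are hypothetical names; the key points it illustrates are that the loader dispatches on the *template* leaf, that saving and loading must agree on exactly which leaves are written (otherwise the file position drifts), and that unsaved leaves are taken from `like`.

```python
import io
from typing import Any, BinaryIO

import jax
import jax.numpy as jnp
import numpy as np

_SAVED = (np.ndarray, np.generic, jax.Array, bool, int, float, complex)

def _save(f: BinaryIO, x: Any) -> None:
    # Arrays and scalars are written; everything else (e.g. functions) is skipped.
    if isinstance(x, _SAVED):
        np.save(f, np.asarray(x), allow_pickle=False)

def _load(f: BinaryIO, like: Any) -> Any:
    # Dispatch on the template leaf; the same leaf types as in `_save` are read.
    if isinstance(like, jax.Array):
        return jnp.asarray(np.load(f))
    elif isinstance(like, np.ndarray):
        return np.load(f)
    elif isinstance(like, np.generic):
        return np.load(f)[()]  # 0-d array -> NumPy scalar
    elif isinstance(like, (bool, int, float, complex)):
        return np.load(f).item()  # 0-d array -> Python scalar
    else:
        return like  # not saved; keep the template's value

def serialise(f: BinaryIO, pytree: Any) -> None:
    for leaf in jax.tree_util.tree_leaves(pytree):
        _save(f, leaf)

def deserialise(f: BinaryIO, like: Any) -> Any:
    leaves, treedef = jax.tree_util.tree_flatten(like)
    return jax.tree_util.tree_unflatten(treedef, [_load(f, leaf) for leaf in leaves])

model = {"w": jnp.arange(3.0), "scale": 0.5, "act": np.tanh, "n": np.int32(4)}
buf = io.BytesIO()
serialise(buf, model)
buf.seek(0)
like = {"w": jnp.zeros(3), "scale": 0.0, "act": np.tanh, "n": np.int32(0)}
out = deserialise(buf, like)
print(out["scale"], out["n"])  # 0.5 4
```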
@jaschau Heads-up: when putting together your PR, branch off the `v050` branch.
@jaschau Any plans to pick this up as a PR? No worries if not, I'll do it; I just want to get serialisation+deserialisation into the next release of Equinox.
Hi @patrick-kidger, sorry for the lack of progress here. I've been side-tracked with other topics, so I haven't got around to working on this beyond our initial discussion. I'm afraid I cannot promise any serious progress on this in the upcoming weeks from my side, so please move ahead if you're eager to work on this.
Equinox models are just PyTrees so they should be very easy to serialise/deserialise; just save the PyTree to disk in whatever way is desired. It might be worth adding some library functions for this just for convenience. Perhaps checking the device of JAX arrays etc?
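On the device question, one option (an assumption, not something settled in this thread) is to move each loaded array onto wherever the corresponding leaf of the template lives. `match_placement` is a hypothetical helper; it relies on `jax.Array.sharding` and on `jax.device_put` accepting a sharding as its target.

```python
import jax
import jax.numpy as jnp
import numpy as np

def match_placement(loaded, like):
    """Place each loaded leaf on the same device(s) as the matching JAX array in `like`."""
    def _place(new, old):
        if isinstance(old, jax.Array):
            return jax.device_put(new, old.sharding)
        return new  # non-JAX leaves are left alone
    return jax.tree_util.tree_map(_place, loaded, like)

like = {"w": jnp.zeros(3)}
loaded = {"w": np.array([1.0, 2.0, 3.0], dtype=np.float32)}  # e.g. freshly read with np.load
out = match_placement(loaded, like)
print(out["w"].devices() == like["w"].devices())  # True
```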
This should respect the `get_state`/`set_state` stuff that's being put together. In addition, there should be a version of `get_state` which inlines its state in the call graph, for faster inference.