GallagherCommaJack opened this issue 2 years ago
One major advantage of the flax approach vs. e.g. running `tree_flatten` and saving the whole checkpoint in a single directory is that it becomes very easy to load subtrees of the state, which is handy if e.g. you don't have enough memory to load the optimizer state but can afford to load the params.
Hmm, interesting question!
So Flax's state dicts and Equinox modules are essentially analogous: they're both just PyTrees of parameters. As such it should also be possible to serialise Equinox modules in custom ways very easily. For example:
```python
class MyModel(eqx.Module):
    submodule1: eqx.Module
    submodule2: eqx.Module

model = MyModel(eqx.nn.Linear(...), eqx.nn.Linear(...))
eqx.tree_serialise_leaves("sub1.eqx", model.submodule1)
eqx.tree_serialise_leaves("sub2.eqx", model.submodule2)
```
It sounds like you might be after some more automated way to do something like the above?
FWIW, I think it should already be possible to use `eqx.tree_serialise_leaves` with tensorstore. Something like:
```python
def filter_spec(f, x):
    if isinstance(x, jnp.ndarray):
        ...  # use tensorstore
    else:
        eqx.default_serialise_filter_spec(f, x)

eqx.tree_serialise_leaves("model.eqx", model, filter_spec)
```
WDYT?
`tree_serialise_leaves` doesn't have any notion of the path within the pytree, so I don't see how this would work? tensorstore checkpoints want to write each array to a separate directory.
here's some example code doing this for `FrozenDict`s, extracted and tweaked starting from the flax implementation:
```python
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union

import jax
import tensorflow as tf
from flax import core, traverse_util
from flax import serialization as fser

# NOTE: these utilities live under jax.experimental; the exact module paths have
# moved between JAX versions (e.g. jax.experimental.gda_serialization.serialization).
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.gda_serialization.serialization import (
    GlobalAsyncCheckpointManager,
    get_tensorstore_spec,
)

# Placeholder-string prefix used to mark GDA leaves in the metadata tree
# (defined elsewhere, as in flax's checkpointing code).
GDA_PH = ...


def _split_gdas(
    target: Dict[str, Any]
) -> Tuple[Dict[str, Any], List[Tuple[Union[jax.Array, GlobalDeviceArray], str]]]:
    # When target is a single leaf instead of a pytree dict.
    if not isinstance(target, (core.FrozenDict, dict)):
        if isinstance(target, jax.Array) or isinstance(target, GlobalDeviceArray):
            return GDA_PH, [(target, "")]
        return target, []
    # Traverse the target and handle GlobalDeviceArrays.
    flattened = traverse_util.flatten_dict(target, keep_empty_nodes=True)
    gda_targets = []
    for key, value in flattened.items():
        if isinstance(value, jax.Array) or isinstance(value, GlobalDeviceArray):
            subpath = "/".join(key)
            gda_targets.append((value, subpath))
            flattened[key] = GDA_PH + subpath
    target = traverse_util.unflatten_dict(flattened)
    return target, gda_targets


def save_commit(meta_path: Path, gda_path: Path, px_meta):
    with tf.io.gfile.GFile(meta_path, "wb") as f:
        f.write(fser.to_bytes(px_meta))
    with tf.io.gfile.GFile(gda_path / "commit_success.txt", "w") as f:
        f.write("success!")


def prep_gdas(prefix: Path, state: Dict[str, Any]):
    state_meta, gda_list = _split_gdas(fser.to_state_dict(state))
    gda_list = [
        (
            gda,
            get_tensorstore_spec(str(prefix / subpath)),
        )
        for gda, subpath in gda_list
    ]
    # (assumes at least one array leaf in the state)
    gda_list, tspecs = [list(x) for x in zip(*gda_list)]
    return state_meta, gda_list, tspecs


def _raw_save_gda_checkpoint(
    prefix: Union[str, Path],
    state: Dict[str, Any],
    manager: GlobalAsyncCheckpointManager,
    gda_suffix="gda",
    meta_suffix="meta",
):
    prefix = Path(prefix)
    state_meta, gda_list, tspecs = prep_gdas(prefix / gda_suffix, state)
    manager.wait_until_finished()
    manager.check_for_errors()
    manager.serialize(
        gda_list,
        tspecs,
        on_commit_callback=partial(
            save_commit,
            meta_path=prefix / meta_suffix,
            gda_path=prefix / gda_suffix,
            px_meta=state_meta,
        ),
    )
```
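A rough usage sketch, just to show how the pieces fit together (the bucket path and state contents here are placeholders):

```python
# Hypothetical call site: asynchronously checkpoint params + optimiser state.
manager = GlobalAsyncCheckpointManager()
state = {"params": params, "opt_state": opt_state}  # pytree of (global) device arrays
_raw_save_gda_checkpoint("gs://my-bucket/run-0/step_1000", state, manager)
# training continues; the manager commits the checkpoint in the background
```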
Hi, I actually implemented basically this in my code base (supported by code for key paths). Happy to upstream.
Hmm. FWIW I think if you want something like this - which is quite complicated - then it's probably out of scope for including in Equinox itself. (Especially as tensorstore is explicitly stated not to be a stable API.)
I think it should be easy enough to make Equinox compatible with whatever you have in mind, though. Since all of Equinox's models are just PyTrees, you should be able to flatten them down and extract all the arrays:
```python
### serialisation
hyperparams = ...
key = jax.random.PRNGKey(0)
model = eqx.nn.MLP(*hyperparams, key=key)
arrays = [leaf for leaf in jax.tree_util.tree_leaves(model) if eqx.is_array(leaf)]
arrays = {i: array for i, array in enumerate(arrays)}
# now serialise this dictionary however you like

### deserialisation
# let's deserialise just the 0th and 5th array
load_arrays = {0: ..., 5: ...}
key = jax.random.PRNGKey(0)
load_keys, load_values = zip(*load_arrays.items())
model = eqx.filter_eval_shape(lambda k: eqx.nn.MLP(*hyperparams, key=k), key)

def get_arrays(m):
    arrays = [leaf for leaf in jax.tree_util.tree_leaves(m)
              if isinstance(leaf, jax.ShapeDtypeStruct)]
    return [arrays[i] for i in load_keys]

model = eqx.tree_at(get_arrays, model, load_values)
```
This in particular uses `filter_eval_shape` as a trick to generate the appropriate PyTree structure without actually initialising any arrays (which might take up memory).

FWIW I don't think the above code is perfect, as `filter_eval_shape` will currently promote Python `float`s etc. into `ShapeDtypeStruct`s rather than leaving them as Python builtins: `filter_eval_shape(lambda: 1)` returns `ShapeDtypeStruct(shape=(), dtype=int32)`. I'll submit a fix for that shortly...
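For concreteness, one way to fill in the "serialise this dictionary however you like" step above would be a plain NumPy archive; this is just an illustration, and a tensorstore-backed version would slot into the same place:

```python
import numpy as np
import jax.numpy as jnp

# Write every array under its (stringified) index...
np.savez("arrays.npz", **{str(i): np.asarray(a) for i, a in arrays.items()})

# ...and later read back only the entries we care about.
with np.load("arrays.npz") as f:
    load_arrays = {0: jnp.asarray(f["0"]), 5: jnp.asarray(f["5"])}
```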
yeah, the only thing I don't like about the `tree_flatten` approach is that all the path information is lost, so if afterwards I want to e.g. just load the params and not the optimizer state, I'm entirely out of luck
That's still doable. Adapting the previous example, to load only (say) the first layer's weight matrix of the MLP:
```python
leaves = jax.tree_util.tree_leaves

### tweaked serialisation from above
hyperparams = ...
key = jax.random.PRNGKey(0)
model = eqx.nn.MLP(*hyperparams, key=key)
arrays = {i: leaf for i, leaf in enumerate(leaves(model)) if eqx.is_array(leaf)}
# now serialise this dictionary however you like

# find the index of the first layer's weight matrix
spec = jax.tree_util.tree_map(lambda _: False, model)
spec = eqx.tree_at(lambda m: m.layers[0].weight, spec, True)
indices = [i for i, (m, s) in enumerate(zip(leaves(model), leaves(spec)))
           if eqx.is_array(m) and s]
load_arrays = {i: ... for i in indices}  # now deserialise as above
```
Probably you can tweak/tidy-up/bugfix my sketch code if needed.
The key point here is that "pytree structure" is really a notion that only makes sense within the context of a particular Python runtime, and it's probably a mistake to try and serialise that to disk. (c.f. all the difficulties that `pickle` faces here.)
> The key point here is that "pytree structure" is really a notion that only makes sense within the context of a particular Python runtime, and it's probably a mistake to try and serialise that to disk. (c.f. all the difficulties that `pickle` faces here.)
I'm not so sure about that - with nested dicts there's certainly something more canonical, and the same would be true of anything else that could in principle be represented by an algebraic data type.
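For instance, with nested dicts the canonical structure is just the key paths, which is what the flax-style code above leans on; a tiny illustration using flax's `traverse_util`:

```python
from flax import traverse_util

nested = {"params": {"dense": {"w": 1, "b": 2}}, "opt_state": {"mu": 3}}
# Flatten the nesting into slash-separated paths, one per leaf.
flat = {"/".join(path): leaf for path, leaf in traverse_util.flatten_dict(nested).items()}
# {'params/dense/w': 1, 'params/dense/b': 2, 'opt_state/mu': 3}
```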
I'd like to use equinox for some fairly large-scale training runs, but the state for those models is often too large to fit on a single accelerator, so gathering all the state to serialize with numpy is far from ideal.
Also, checkpoints can be large, and saving them can take a while, so async support is nice to have.
Jax provides a low-level serialization API for individual arrays which should be perfect here, but making it work nicely with arbitrary equinox modules is going to be nontrivial.
Flax creates a directory tree based on the nested structure of their state dicts, maybe something similar could be done with equinox modules?
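As a sketch of what I have in mind (assuming a JAX/Equinox version where module leaves come with key paths; the exact names will depend on the version):

```python
import jax
import equinox as eqx

model = eqx.nn.MLP(in_size=2, out_size=2, width_size=8, depth=2, key=jax.random.PRNGKey(0))

# Each array leaf gets a directory-like name derived from its position in the
# module, analogous to the nested flax state dicts above.
for path, leaf in jax.tree_util.tree_flatten_with_path(model)[0]:
    if eqx.is_array(leaf):
        print(jax.tree_util.keystr(path), leaf.shape)  # e.g. ".layers[0].weight (8, 2)"
```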