patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

support tensorstore checkpoints using gda_serialization api #216

Open GallagherCommaJack opened 2 years ago

GallagherCommaJack commented 2 years ago

I'd like to use equinox for some fairly large-scale training runs, but the state for those models is often too large to fit on a single accelerator, so gathering all the state to serialize with numpy is far from ideal.

Also, checkpoints can be large, and saving them can take a while, so async support is nice to have.

JAX provides a low-level serialization API for individual arrays which should be perfect here, but making it work nicely with arbitrary Equinox modules is going to be nontrivial.

Flax creates a directory tree based on the nested structure of its state dicts; maybe something similar could be done with Equinox modules?

GallagherCommaJack commented 2 years ago

One major advantage of the Flax approach over, e.g., running tree_flatten and saving the whole checkpoint in a single directory is that it becomes very easy to load subtrees of the state. That is handy if, e.g., you don't have enough memory to load the optimizer state but can afford to load the params.
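A minimal sketch of that directory-per-leaf layout with partial loading. It uses `jax.tree_util.tree_flatten_with_path` to recover each leaf's path and plain `.npy` files as a stand-in for tensorstore; `save_tree`, `load_subtree` and `_key_name` are hypothetical helper names, and the nested dict stands in for any PyTree (an Equinox module should flatten the same way, though the exact key types may differ by version).

```python
import tempfile
from pathlib import Path

import jax
import jax.numpy as jnp
import numpy as np


def _key_name(key) -> str:
    # Key entries expose .key (dicts), .name (attributes) or .idx (sequences);
    # fall back to str() for anything else.
    for attr in ("key", "name", "idx"):
        if hasattr(key, attr):
            return str(getattr(key, attr))
    return str(key)


def save_tree(root: Path, tree) -> None:
    """Write every leaf to its own file, mirroring the tree's nesting."""
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    for path, leaf in leaves_with_paths:
        names = [_key_name(k) for k in path]  # assumes the tree is not a bare leaf
        target = root.joinpath(*names[:-1], names[-1] + ".npy")
        target.parent.mkdir(parents=True, exist_ok=True)
        np.save(target, np.asarray(leaf))


def load_subtree(root: Path, subdir: str) -> dict:
    """Load only the leaves stored under root/subdir, keyed by relative path."""
    base = root / subdir
    return {
        f.relative_to(base).with_suffix("").as_posix(): np.load(f)
        for f in sorted(base.rglob("*.npy"))
    }


state = {
    "params": {"layer1": {"weight": jnp.ones((2, 3)), "bias": jnp.zeros(3)}},
    "opt_state": {"mu": jnp.zeros((2, 3))},
}
with tempfile.TemporaryDirectory() as tmp:
    save_tree(Path(tmp), state)
    # Only files under params/ are read; opt_state/ is never touched.
    params_only = load_subtree(Path(tmp), "params")
```

Because each leaf lives at a path derived from its position in the tree, restoring just the params is a matter of choosing which directory to read.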

patrick-kidger commented 2 years ago

Hmm, interesting question!

So Flax's state dicts and Equinox modules are essentially analogous: both are just PyTrees of parameters. As such, it should also be possible to serialise Equinox modules in custom ways very easily. For example:

import equinox as eqx

class MyModel(eqx.Module):
    submodule1: eqx.Module
    submodule2: eqx.Module

model = MyModel(eqx.nn.Linear(...), eqx.nn.Linear(...))
# serialise each submodule to its own file
eqx.tree_serialise_leaves("sub1.eqx", model.submodule1)
eqx.tree_serialise_leaves("sub2.eqx", model.submodule2)

It sounds like you might be after some more automated way to do something like the above?

FWIW, I think it should already be possible to use eqx.tree_serialise_leaves with tensorstore. Something like:

def filter_spec(f, x):
    if isinstance(x, jnp.ndarray):
        ... # use tensorstore
    else:
        eqx.default_serialise_filter_spec(f, x)

eqx.tree_serialise_leaves("model.eqx", model, filter_spec)

WDYT?

GallagherCommaJack commented 2 years ago

tree_serialise_leaves doesn't have any notion of the path within the pytree, so I don't see how this would work?

GallagherCommaJack commented 2 years ago

tensorstore checkpoints want to write each array to a separate directory

GallagherCommaJack commented 2 years ago

here's some example code doing this for FrozenDicts, extracted + tweaked starting from the flax implementation

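# Imports this snippet relies on. Module paths are assumptions based on the
# JAX/Flax releases of the time (GlobalDeviceArray has since been folded into jax.Array).
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union

import jax
import tensorflow as tf
from flax import core, traverse_util
from flax import serialization as fser
from flax.training.checkpoints import GDA_PH  # placeholder-string prefix (assumed location)
from jax.experimental.gda_serialization.serialization import (
    GlobalAsyncCheckpointManager,
    get_tensorstore_spec,
)
from jax.experimental.global_device_array import GlobalDeviceArray

def _split_gdas(
    target: Dict[str, Any]
) -> Tuple[Dict[str, Any], List[Tuple[Union[jax.Array, GlobalDeviceArray], str]]]: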
    # When target is a single leaf instead of a pytree dict.
    if not isinstance(target, (core.FrozenDict, dict)):
        if isinstance(target, (jax.Array, GlobalDeviceArray)):
            return GDA_PH, [(target, "")]
        return target, []
    # Traverse the target and handle GlobalDeviceArrays.
    flattened = traverse_util.flatten_dict(target, keep_empty_nodes=True)
    gda_targets = []
    for key, value in flattened.items():
        if isinstance(value, (jax.Array, GlobalDeviceArray)):
            subpath = "/".join(key)
            gda_targets.append((value, subpath))
            flattened[key] = GDA_PH + subpath
    target = traverse_util.unflatten_dict(flattened)
    return target, gda_targets

def save_commit(meta_path: Path, gda_path: Path, px_meta):
    with tf.io.gfile.GFile(meta_path, "wb") as f:
        f.write(fser.to_bytes(px_meta))
    with tf.io.gfile.GFile(gda_path / "commit_success.txt", "w") as f:
        f.write("success!")

def prep_gdas(prefix: Path, state: Dict[str, Any]):
    state_meta, gda_list = _split_gdas(fser.to_state_dict(state))
    gda_list = [
        (
            gda,
            get_tensorstore_spec(str(prefix / subpath)),
        )
        for gda, subpath in gda_list
    ]
    gda_list, tspecs = [list(x) for x in zip(*gda_list)]
    return state_meta, gda_list, tspecs

def _raw_save_gda_checkpoint(
    prefix: Union[str, Path],
    state: Dict[str, Any],
    manager: GlobalAsyncCheckpointManager,
    gda_suffix="gda",
    meta_suffix="meta",
):
    prefix = Path(prefix)
    state_meta, gda_list, tspecs = prep_gdas(prefix / gda_suffix, state)
    manager.wait_until_finished()
    manager.check_for_errors()
    manager.serialize(
        gda_list,
        tspecs,
        on_commit_callback=partial(
            save_commit,
            meta_path=prefix / meta_suffix,
            gda_path=prefix / gda_suffix,
            px_meta=state_meta,
        ),
    )

dlwh commented 2 years ago

Hi, I actually implemented basically this in my code base. Happy to upstream.

https://github.com/stanford-crfm/levanter/blob/fad60aa08efb28566ac4099b15ef75e345fb59e9/src/levanter/tensorstore_serialization.py#L23

which is supported by code for key paths:

https://github.com/stanford-crfm/levanter/blob/fad60aa08efb28566ac4099b15ef75e345fb59e9/src/levanter/jax_utils.py#L227

patrick-kidger commented 2 years ago

Hmm. FWIW I think if you want something like this, which is quite complicated, then it is probably out of scope for including in Equinox itself. (Especially as tensorstore is explicitly stated not to be a stable API.)

I think it should be easy enough to make Equinox compatible with whatever you have in mind, though. Since all of Equinox's models are just PyTrees, you should be able to just flatten them down and extract all the arrays:

import equinox as eqx
import jax

### serialisation
hyperparams = ...
key = jax.random.PRNGKey(0)
model = eqx.nn.MLP(*hyperparams, key=key)
arrays = [leaf for leaf in jax.tree_util.tree_leaves(model) if eqx.is_array(leaf)]
arrays = {i: array for i, array in enumerate(arrays)}
# now serialise this dictionary however you like

### deserialisation
# let's deserialise just the 0th and 5th array
load_arrays = {0: ..., 5: ...}
key = jax.random.PRNGKey(0)
load_keys, load_values = zip(*load_arrays.items())
model = eqx.filter_eval_shape(lambda k: eqx.nn.MLP(*hyperparams, key=k), key)
def get_arrays(m):
    arrays = [leaf for leaf in jax.tree_util.tree_leaves(m) if isinstance(leaf, jax.ShapeDtypeStruct)]
    return [arrays[i] for i in load_keys]
model = eqx.tree_at(get_arrays, model, load_values)

In particular, this uses filter_eval_shape as a trick to generate the appropriate PyTree structure without actually initialising any arrays (which might take up memory).
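The same idea can be sketched in plain JAX. This is only an illustration: it uses `jax.eval_shape` on a hypothetical `init_params` function as a stand-in for `eqx.filter_eval_shape` wrapped around a model constructor, and a dict in place of a real module.

```python
import jax
import jax.numpy as jnp


def init_params(key):
    # Hypothetical stand-in for a model constructor such as eqx.nn.MLP(..., key=key).
    k1, k2 = jax.random.split(key)
    return {
        "weight": jax.random.normal(k1, (4, 3)),
        "bias": jax.random.normal(k2, (4,)),
    }


# Trace the constructor abstractly: the result has the right tree structure,
# but every leaf is a ShapeDtypeStruct rather than an allocated array.
skeleton = jax.eval_shape(init_params, jax.random.PRNGKey(0))

# Pretend only the weight was read back from disk.
loaded = {"weight": jnp.ones((4, 3))}

# Fill in the loaded leaves; everything else stays a placeholder.
restored = {name: loaded.get(name, leaf) for name, leaf in skeleton.items()}
```

Here `restored["weight"]` is a real array while `restored["bias"]` is still a `jax.ShapeDtypeStruct`, so nothing was allocated for the leaves that were not loaded.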

patrick-kidger commented 2 years ago

FWIW I don't think the above code is perfect as filter_eval_shape will currently promote Python floats etc. into ShapeDtypeStructs rather than leaving them as Python builtins: filter_eval_shape(lambda: 1) returns ShapeDtypeStruct(shape=(), dtype=int32). I'll submit a fix for that shortly...

GallagherCommaJack commented 2 years ago

Yeah, the only thing I don't like about the tree_flatten approach is that all the path information is lost, so if I later want to, e.g., load just the params and not the optimizer state, I'm entirely out of luck.

patrick-kidger commented 2 years ago

That's still doable. Adapting the previous example, to load only a single weight matrix of the MLP:

leaves = jax.tree_util.tree_leaves

### tweaked serialisation from above
hyperparams = ...
key = jax.random.PRNGKey(0)
model = eqx.nn.MLP(*hyperparams, key=key)
arrays = {i: leaf for i, leaf in enumerate(leaves(model)) if eqx.is_array(leaf)}
# now serialise this dictionary however you like

# find the index of a weight matrix (here the first layer's; an MLP keeps its Linear layers in `layers`)
spec = jax.tree_util.tree_map(lambda _: False, model)
spec = eqx.tree_at(lambda m: m.layers[0].weight, spec, True)
indices = [i for i, (m, s) in enumerate(zip(leaves(model), leaves(spec)))
              if eqx.is_array(m) and s]
load_arrays = {i: ... for i in indices}  # now deserialise as above.

Probably you can tweak/tidy-up/bugfix my sketchcode if needed.
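An alternative way to find those indices, sketched here under the assumption that a JAX version providing `jax.tree_util.tree_flatten_with_path` and `jax.tree_util.keystr` is available: build a lookup from each leaf's key path to its position in the flattened list. The `state` dict and the names `index_of` and `wanted` are illustrative only.

```python
import jax
import jax.numpy as jnp

state = {
    "params": {"weight": jnp.ones((2, 3)), "bias": jnp.zeros(2)},
    "opt_state": {"mu": jnp.zeros((2, 3))},
}

leaves_with_paths, treedef = jax.tree_util.tree_flatten_with_path(state)

# Map a readable path string to the leaf's position in the flat list.
index_of = {
    jax.tree_util.keystr(path): i for i, (path, _) in enumerate(leaves_with_paths)
}

# Flat indices of every leaf that lives under "params".
wanted = [i for name, i in index_of.items() if name.startswith("['params']")]
```

The indices in `wanted` can then be used as the keys of `load_arrays` in the sketch above, without building a boolean spec tree by hand.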

The key point here is that "pytree structure" is really a notion that only makes sense within the context of a particular Python runtime, and it's probably a mistake to try and serialise that to disk. (c.f. all the difficulties that pickle faces here.)

GallagherCommaJack commented 2 years ago

> The key point here is that "pytree structure" is really a notion that only makes sense within the context of a particular Python runtime, and it's probably a mistake to try and serialise that to disk. (c.f. all the difficulties that pickle faces here.)

I'm not so sure about that. With nested dicts there's certainly something more canonical, and the same would be true of anything else that could in principle be represented by an algebraic data type.
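To make the nested-dict view concrete, here is a rough sketch that converts an arbitrary PyTree into nested dicts with string keys by walking JAX key paths. `to_nested_dict` and `_key_name` are hypothetical helpers, and the `Linear` NamedTuple is only a stand-in for a module type; it is one possible encoding, not an established format.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Linear(NamedTuple):
    # Stand-in for a module with named array fields.
    weight: jax.Array
    bias: jax.Array


def _key_name(key) -> str:
    # Key entries expose .key (dicts), .name (attributes) or .idx (sequences).
    for attr in ("key", "name", "idx"):
        if hasattr(key, attr):
            return str(getattr(key, attr))
    return str(key)


def to_nested_dict(tree) -> dict:
    """Re-express a PyTree as nested dicts keyed by path components."""
    nested: dict = {}
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    for path, leaf in leaves_with_paths:
        names = [_key_name(k) for k in path]  # assumes the tree is not a bare leaf
        node = nested
        for name in names[:-1]:
            node = node.setdefault(name, {})
        node[names[-1]] = leaf
    return nested


model = {"layers": [Linear(jnp.ones((2, 3)), jnp.zeros(2))]}
nested = to_nested_dict(model)
```

The result has the shape `{"layers": {"0": {"weight": ..., "bias": ...}}}`, which no longer depends on the Python classes that produced it.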