Open kazewong opened 1 week ago
So one way to do this with the current API is to create the model 'skeleton' via `eqx.filter_eval_shape(SomeModule, ...)`, and then fill in the parameters with `eqx.tree_at` afterwards.
I do this when translating model weights from PyTorch, for example.
I'm not sure what the cleanest way of doing this in general is. Ideally it wouldn't require wrapping the constructors. These can be thought of as functions that return pytrees, so something respecting that in some way would be ideal.
I see! I tried using `filter_eval_shape` and it does work for me. Do you think it is worth making a short tutorial notebook for this? I have the script, so converting it into a notebook shouldn't take too long. Let me know if this would be useful, and I can work on a PR.
If this ends up being the best way to do it, then yes! I think having an example would be really useful.
Right now I'm not completely convinced it is the best way, though. It's certainly a lot of work. I'm wondering if there's some way we can directly create the arrays in the right way, by wrapping the constructor in a filtered/tree-map'd version of the usual JAX operations here (`jax.make_array_from_single_device_arrays` etc.). Right now I've simply not thought about it yet!
I agree that if there is an ergonomic way to initialize a sharded model, that would be great. I am trying to figure out how best to do this, and here are some thoughts:
The ideal scenario is that we have a function that does the following:

```python
def init_shard_model(ModelClass, mesh, sharding, ...) -> model:
    ...
```
I have so far run into two complications:
I think this may be why the devs of Levanter rolled their own nn library, haliax, on top of equinox, so they can handle these issues, which is essentially smart wrapping around the equinox layers? @dlwh
I am gonna try a number of things in the coming weeks, any input will be awesome!
This is indeed one of the main reasons I did names. (There are a few more but I won't belabor them here.) It basically defines this problem away, though it does mean you have to do weird gymnastics when you have square arrays that have semantically identical axes. But it's a price I'm happy to pay. (Obviously happy to have you over in Haliax land, but only if it appeals.)
You could instead follow something more like what flax does, if you wanted. See, for example, t5x. It's basically the same as Haliax except the names of the axes aren't carried around with the array, so you occasionally have to sprinkle in more `with_sharding_constraint`s. Obviously I'm partial to Haliax, but tastes vary.
Regardless, either way the basic idea is to do initialization inside jit and use with_sharding_constraint and/or out_shardings to ensure things are correct. IMHO re (2) the right way to do any of this is to always always do as much as you can in a "global" way using jit, and only fall back on make_array_from_callback when you absolutely have to (or for data loading).
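A minimal sketch of the "initialization inside jit" idea described above (the shapes and axis names are just illustrative): jit the initializer and pass `out_shardings`, so the parameters are created directly with the desired sharding rather than being materialized on one device and moved.

```python
from functools import partial

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1D mesh over all available devices; axis name is illustrative.
mesh = Mesh(np.array(jax.devices()), ("data",))
sharding = NamedSharding(mesh, P("data"))

# out_shardings tells the compiler to produce the output already sharded;
# jax.lax.with_sharding_constraint could equally be applied to intermediates.
@partial(jax.jit, out_shardings=sharding)
def init_weight(key):
    return jax.random.normal(key, (8, 4))

weight = init_weight(jax.random.PRNGKey(0))
```

With this pattern each device only ever holds its own shard of the parameters, which is the "global" style recommended above.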
Hi all,
I am wondering what the preferred way is to create a model that is too large to fit on a single device.
As a reference starting point: if I use data parallelism, I will first create per-device data arrays, then use `jax.make_array_from_single_device_arrays` to put them on the global mesh (this is basically following https://jax.readthedocs.io/en/latest/_autosummary/jax.make_array_from_single_device_arrays.html#jax.make_array_from_single_device_arrays).
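For concreteness, a sketch of that data-parallel recipe (shapes and axis names are illustrative, not from the linked doc): put one shard on each local device, then assemble them into a single globally sharded array.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("data",))
sharding = NamedSharding(mesh, P("data"))  # shard leading axis across devices
ndev = len(jax.devices())

# One per-device shard per local device.
per_device = [jax.device_put(np.ones((8, 4)), d) for d in jax.devices()]

# Assemble the shards into one global array on the mesh.
global_arr = jax.make_array_from_single_device_arrays(
    (8 * ndev, 4), sharding, per_device
)
```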
Since by default the pre-defined modules in equinox will initialize the full set of parameters on every device, I cannot just follow the guide in https://docs.kidger.site/equinox/tricks/#custom-parameter-initialisation and update the parameters to the sharded version after I create the model.

My current way of bypassing this is to create a wrapper class around the `nn.Module`s I want to use, so I can create the sharded version of the parameters on each device and then combine them as I would in the data-parallelism case. Here is a minimal example wrapping the `eqx.nn.Linear` class: https://gist.github.com/kazewong/c976b48c5870d866496740341382acb5

Since the multi-host interface in JAX is still experimental, we probably don't want to put too much of this into the equinox core code. To make creating large models easier now, I think a wrapper class or a decorator is probably the easiest way, but I want to see what people think about this before submitting a PR. @patrick-kidger