patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Initialization of large models in a multi-host environment #778

Open kazewong opened 1 week ago

kazewong commented 1 week ago

Hi all,

I am wondering what the preferred way is to create a model that is too large to fit on a single device.

As a reference starting point: if I use data parallelism, I first create per-device data arrays and then use make_array_from_single_device_arrays to put them on the global mesh (basically following https://jax.readthedocs.io/en/latest/_autosummary/jax.make_array_from_single_device_arrays.html#jax.make_array_from_single_device_arrays).
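
For concreteness, a minimal sketch of that data-parallel starting point might look like the following (the shapes and the single "data" mesh axis are purely illustrative; on multiple hosts, each process would only build the shards for its own local devices):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis named "data" spanning all devices (illustrative).
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))

# Global batch split along its leading axis; each device holds a (2, 128) shard.
n_devices = len(jax.devices())
global_shape = (2 * n_devices, 128)

# In practice each shard would be this host's slice of the real data.
local_shards = [
    jax.device_put(jnp.ones((2, 128)), d) for d in jax.local_devices()
]
batch = jax.make_array_from_single_device_arrays(global_shape, sharding, local_shards)
```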

Since by default the pre-defined modules in equinox initialize the full set of parameters on every device, I cannot just follow the guide in https://docs.kidger.site/equinox/tricks/#custom-parameter-initialisation and update the parameters to the sharded version after I create the model.

My current way of bypassing this is to create a wrapper class around the nn.Modules I want to use, so that I can create the sharded version of the parameters on each device and then combine them as I would in the data-parallelism case.

Here is a minimal example wrapping the eqx.nn.Linear class: https://gist.github.com/kazewong/c976b48c5870d866496740341382acb5

Since the multi-host interface in JAX is still experimental, we probably don't want to put too much of this into the equinox core code. To make creating large models easier now, I think a wrapper class or a decorator is probably the easiest route, but I want to see what people think before submitting a PR. @patrick-kidger

patrick-kidger commented 1 week ago

So one way to do this with the current API is to create the model 'skeleton' via

eqx.filter_eval_shape(SomeModule, ...)

and then fill in the parameters with eqx.tree_at afterwards.

I do this when translating model weights from PyTorch, for example.
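
As a rough sketch of that pattern, using eqx.nn.Linear purely as an example (where the replacement arrays actually come from, e.g. already-sharded arrays or converted PyTorch weights, is left open here):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

# Abstract "skeleton": leaves are jax.ShapeDtypeStruct, so no parameter
# memory is allocated yet.
skeleton = eqx.filter_eval_shape(eqx.nn.Linear, 512, 256, key=jax.random.PRNGKey(0))

# Materialise the real arrays however is appropriate (zeros as a stand-in here)...
new_weight = jnp.zeros(skeleton.weight.shape, skeleton.weight.dtype)
new_bias = jnp.zeros(skeleton.bias.shape, skeleton.bias.dtype)

# ...then swap them into the skeleton with tree_at.
model = eqx.tree_at(lambda m: (m.weight, m.bias), skeleton, (new_weight, new_bias))
```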

I'm not sure what the cleanest way of doing this in general is. Ideally it wouldn't require wrapping the constructors. These can be thought of as functions that return pytrees, so something that respects that structure would be ideal.

kazewong commented 1 week ago

I see! I tried using filter_eval_shape and it works for me. Do you think it is worth making a short tutorial notebook for this? I have the script, so converting it into a notebook shouldn't take too long. Let me know if this would be useful and I can work on a PR.

patrick-kidger commented 1 week ago

If this ends up being the best way to do it, then yes! I think having an example would be really useful.

Right now I'm not completely convinced it is the best way, though. It's certainly a lot of work. I'm wondering if there's some way we can directly create the arrays in the right way by wrapping the constructor in a filtered/tree-map'd version of the usual JAX operations here (make_array_from_single_device_arrays etc.). Right now I've simply not thought about it yet!

kazewong commented 3 days ago

I agree that if there is an ergonomic way to initialize a sharded model, that would be great. I am trying to figure out how best to do this, and here are some thoughts:

The ideal scenario is that we have a function that does the following:

def init_shard_model(ModelClass, mesh, sharding, ...) -> model:
    ...

I have so far run into two complications:

  1. Sharding the internals of a model differently. Take a linear layer with 512 inputs and 256 outputs as an example: the weight matrix will have shape (256, 512) and the bias vector will have shape (256,). Say I want to shard the weight matrix across its second axis and the bias vector across its first axis; where should I supply that information? Right now I am defining a sharding function for each kind of layer, so it is set there (see the sketch after this list).
  2. Sharding structured arrays. Say my model has an upper-triangular matrix as its weight at initialization, and I want to shard it along one axis. This means that when I construct the local arrays, I need to pass in the process IDs or some other bookkeeping parameters to make sure the correct part is initialized on each host before I combine them.
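
To make point 1 concrete, here is a rough sketch of the per-layer sharding-function idea (the function name, the single "model" mesh axis, and the particular PartitionSpecs are all just for illustration, and use_bias=True is assumed):

```python
import equinox as eqx
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

def linear_shardings(layer: eqx.nn.Linear) -> eqx.nn.Linear:
    """Return a pytree with the same structure as `layer`, whose array leaves
    are replaced by the shardings we want for them."""
    weight_sharding = NamedSharding(mesh, P(None, "model"))  # second axis sharded
    bias_sharding = NamedSharding(mesh, P("model"))          # first axis sharded
    return eqx.tree_at(
        lambda l: (l.weight, l.bias), layer, (weight_sharding, bias_sharding)
    )
```

A pytree of shardings built this way can then be handed to e.g. jax.device_put or jax.lax.with_sharding_constraint, both of which accept a pytree of shardings matching the array pytree.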

I think this may be why the Levanter devs rolled their own nn library, haliax, on top of equinox so they could handle these issues; it is essentially smart wrapping around the equinox layers, right? @dlwh

I am gonna try a number of things in the coming weeks; any input would be awesome!

dlwh commented 3 days ago

This is indeed one of the main reasons I did named axes. (There are a few more, but I won't belabor them here.) It basically defines this problem away, though it does mean you have to do weird gymnastics when you have square arrays with semantically identical axes. But that's a price I'm happy to pay. (Obviously I'm happy to have you over in Haliax land, but only if it appeals.)

You could instead follow something more like what flax does, if you wanted; see, for example, t5x. It's basically the same as Haliax except that the names of the axes aren't carried around with the array, so you occasionally have to sprinkle in more with_sharding_constraint calls. Obviously I'm partial to Haliax, but tastes vary.

Either way, the basic idea is to do initialization inside jit and use with_sharding_constraint and/or out_shardings to ensure things are correct. IMHO, re (2), the right way to do any of this is to always do as much as you can in a "global" way using jit, and only fall back on make_array_from_callback when you absolutely have to (or for data loading).
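
For what it's worth, a minimal sketch of that idea might look as follows, using a single eqx.nn.Linear and a one-axis "model" mesh as placeholders (the shardings chosen, and the assumption that the mesh size divides the parameter dimensions, are purely illustrative):

```python
import equinox as eqx
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

@jax.jit
def init_linear(key):
    # The constructor runs inside jit, so the compiler can materialise each
    # parameter directly with its constrained sharding instead of first
    # building the whole unsharded model on one device.
    layer = eqx.nn.Linear(512, 256, key=key)
    weight = jax.lax.with_sharding_constraint(
        layer.weight, NamedSharding(mesh, P(None, "model"))
    )
    bias = jax.lax.with_sharding_constraint(layer.bias, NamedSharding(mesh, P("model")))
    return eqx.tree_at(lambda l: (l.weight, l.bias), layer, (weight, bias))

model = init_linear(jax.random.PRNGKey(0))
```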