stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

make Linear support overlapping input/output axis names? #53

Open dlwh opened 9 months ago

dlwh commented 9 months ago

Currently Haliax requires that all names in a single named array be unique. In general I think this is a good constraint. However, for Linear layers it's frequently a nuisance, since one often projects to something of the same shape, or wants to keep the same name ("hidden").

So, it might be a good idea to support overlapping names. This will complicate the implementation quite a bit but simplify some juggling outside. I think this is worth the complexity?

Probably we'd rename overlapping "output" names to ${name}_out and then rename them back to ${name} in the result. If we make this a contract, then you can use it to control sharding.
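For context, here is a minimal sketch of the juggling the current uniqueness constraint forces (assuming the usual hax.nn.Linear.init, Axis.alias, and NamedArray.rename APIs; the axis names and sizes here are made up):

import jax.random as jrandom
import haliax as hax

Hidden = hax.Axis("hidden", 64)
Batch = hax.Axis("batch", 8)

# Today Linear(In=Hidden, Out=Hidden) is disallowed, since the result would carry
# two axes named "hidden", so we project to a temporary axis instead...
HiddenOut = Hidden.alias("hidden_out")
linear = hax.nn.Linear.init(In=Hidden, Out=HiddenOut, key=jrandom.PRNGKey(0))

x = hax.random.normal(jrandom.PRNGKey(1), (Batch, Hidden))
# ...and rename the result back to "hidden" by hand.
y = linear(x).rename({"hidden_out": "hidden"})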

cooljoseph1 commented 1 month ago

I think it's easiest to just always rename all axes to ${name}_in for in axes and ${name}_out for out axes. (This guarantees there will be no conflicting names, since the in axes all end in "_in" whereas the out axes all end in "_out".) I've implemented that in the above pull request.

Is there a reason to not do this (e.g., performance issues)?
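A rough sketch of what that scheme would mean for the stored parameters (hypothetical, not the current Linear behavior):

import haliax as hax

Hidden = hax.Axis("hidden", 64)

# Hypothetical: under the always-rename scheme, a Linear from Hidden to Hidden
# would store its weight with axes named "hidden_in" and "hidden_out", so the
# in and out names can never collide even when the caller passes the same axis.
weight_axes = (Hidden.alias("hidden_in"), Hidden.alias("hidden_out"))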

dlwh commented 1 month ago

It messes up FSDP, or at least it makes it so you have to specify that both Embed_in and Embed_out are sharded, which is a bit noisier

cooljoseph1 commented 1 month ago

I don't know how sharding works in Haliax. Would you mind explaining why it messes up sharding?

dlwh commented 4 weeks ago

Well, "messes up" is a bit strong, but the key idea behind sharding in Haliax is mapping named axes to a device mesh axis (cf. the tutorial https://colab.research.google.com/drive/1QX4yH3zRFF3Xiibf1aahETcSQ5nbcUMz). Currently, to set up FSDP, we do:

model = hax.shard(model, {"embed": "data"})

and this means that every "embed" axis in the model is sharded across the data axis of the device mesh. To add tensor parallelism, you'd do something like:

model = hax.shard(model, {"embed": "data", "mlp": "model"})

With your change, we'd have to do

model = hax.shard(model, {"embed_in": "data", "embed_out": "data"})

which seems noisier. WDYT?
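For anyone following along, a minimal end-to-end sketch of the FSDP mapping above (assuming hax.shard picks up the ambient device mesh, as in the tutorial; the mesh layout and axis sizes are made up for illustration):

import jax
import numpy as np
from jax.sharding import Mesh
import haliax as hax

# Lay all devices out along a single "data" mesh axis (FSDP only, no tensor parallelism).
mesh = Mesh(np.array(jax.devices()), ("data",))

Vocab = hax.Axis("vocab", 32000)
Embed = hax.Axis("embed", 512)

with mesh:
    # An embedding-like parameter: every axis named "embed" gets sharded across
    # the "data" axis of the device mesh.
    embeddings = hax.random.normal(jax.random.PRNGKey(0), (Vocab, Embed))
    embeddings = hax.shard(embeddings, {"embed": "data"})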

cooljoseph1 commented 4 weeks ago

They seem pretty much equally noisy to me, and I think it's fine to make that change. In the first one you need separate names for all your axes in a sequence of linear layers, which can be just as confusing.

I think it ultimately comes down to needing a disjoint union of axis specs, not a plain union, and I don't think that is possible without renaming things.

Perhaps one could create some kind of tree (or DAG) of axes that are derived from other axes, and then, when sharding, automagically also shard any sub-axes, but that feels like overcomplicating things.
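To make the disjoint-union point concrete, a small hypothetical illustration:

import haliax as hax

Hidden = hax.Axis("hidden", 64)

# With In = Out = Hidden, a plain union of the axis specs collapses to one axis
# and can no longer tell the input role from the output role:
assert {Hidden} | {Hidden} == {Hidden}

# A disjoint union keeps both roles, but only by giving them distinct names:
assert len({Hidden.alias("hidden_in"), Hidden.alias("hidden_out")}) == 2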