patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

`filter_shard` test crashes on JAX 0.4.29 #755

Open · patrick-kidger opened this issue 4 months ago

patrick-kidger commented 4 months ago

Looks like JAX used to do some "broadcasting" here, and no longer does. Bearing in mind that a PyTree may have arrays of multiple ranks, I'm not immediately sure what the appropriate fix is.
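For concreteness, here's roughly the situation (a minimal sketch; the `Linear` module, mesh, and axis name below are illustrative, not the actual test):

```python
import jax
import jax.numpy as jnp
import equinox as eqx
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative only, not the actual test: a PyTree whose array leaves
# have different ranks.
class Linear(eqx.Module):
    weight: jax.Array  # rank 2
    bias: jax.Array    # rank 1

model = Linear(weight=jnp.zeros((8, 4)), bias=jnp.zeros((8,)))

mesh = Mesh(jax.devices(), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))  # a single, rank-1 spec

# The question: how should this one sharding apply to both the rank-2
# weight and the rank-1 bias at once?
sharded_model = eqx.filter_shard(model, sharding)
```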

Tagging @homerjed, who originally contributed this; also @dlwh for visibility. What do you think?

homerjed commented 4 months ago

Ah, this is interesting. I'll have a look at this, but I'm not sure of the fix a priori!

nasyxx commented 3 months ago

Would it be possible to fix this by manually reshaping to align the ranks, applying the sharding, and then reshaping back to the original shape?
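A minimal sketch of that idea for a single leaf (the shapes, mesh, and axis name are made up, and whether the final reshape preserves the intended layout is part of the open question):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("data",))
flat_sharding = NamedSharding(mesh, P("data"))  # rank-1 spec

x = jnp.zeros((8, 3))  # rank-2 leaf (size assumed divisible by device count)
flat = jax.device_put(x.reshape(-1), flat_sharding)  # shard at rank 1
x = flat.reshape(x.shape)  # reshape back to the original shape
```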

dlwh commented 3 months ago

We're wrestling with sharding over in Levanter land right now, especially in the context of changing meshes around and such, which isn't exactly what you're talking about.

Because of the named-axis stuff in Haliax, the broadcasting issue you're talking about shouldn't come up, since we compute a unique sharding for every named array. (Not shilling named axes, really, just excusing my lack of insight here.)

I do think that in the general case a model can't really be sharded usefully with broadcasting (since the FSDP axis and the TP axis aren't always in the same position), so it's not really a permanent solution. Data is a bit clearer: you almost always want to shard all arrays along the 0th axis, except for really fancy workloads.

So maybe just have a helper/special case for "shard every array along the 0th axis" and require anything else to be fully specified?
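A rough sketch of such a helper (the name, mesh, and axis name are placeholders, and it assumes the leading dimension of each array leaf divides across the mesh):

```python
import jax
import equinox as eqx
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def shard_along_axis0(tree, mesh: Mesh, axis_name: str = "data"):
    # Shard every array leaf along its leading axis and replicate the rest;
    # scalars and non-array leaves pass through untouched. Anything fancier
    # would have to be fully specified by the caller.
    def maybe_shard(x):
        if eqx.is_array(x) and x.ndim > 0:
            spec = P(axis_name, *([None] * (x.ndim - 1)))
            return jax.device_put(x, NamedSharding(mesh, spec))
        return x
    return jax.tree_util.tree_map(maybe_shard, tree)
```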

patrick-kidger commented 3 months ago

Thanks for the input, folks! (In particular for the great write-up, David.)

It sounds like the general answer here is really "it's complicated", and fundamentally this is a topic that isn't a core competency for Equinox. Properly sharding tensors of varying rank is the kind of thing where a library like Haliax, with its named axes, really starts to shine!

@ChenAo-Phys has done a great job fixing the test as part of #765. So I think we'll keep `eqx.filter_shard` as a thin wrapper around `lax.with_sharding_constraint`, and consider the incompatibility here a JAX-level thing that we don't try to engineer around. (We'll let @dlwh think very hard about this instead ;) )
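For anyone landing here later, a small usage sketch of that resolution (the mesh, axis name, and `train_step` are illustrative): pass a fully specified sharding, and `filter_shard` just constrains the array leaves.

```python
import jax
import equinox as eqx
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))  # shard the leading (batch) axis

@eqx.filter_jit
def train_step(model, batch):
    # Thin wrapper over lax.with_sharding_constraint: array leaves get the
    # constraint, non-array leaves pass through untouched.
    batch = eqx.filter_shard(batch, sharding)
    ...  # placeholder for the actual loss/update computation
    return model
```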