patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

PositionalSharding tensor of different ranks: use sharding.replicate! #815

Closed: davidaknowles closed this 2 months ago

davidaknowles commented 3 months ago

I'm trying to set up sharding on a TPU VM. I have data of different ranks; let's say x is rank-3 and y is rank-2. When I try to adapt the parallelism example, I get an error saying ValueError: One of with_sharding_constraint arguments is incompatible with its sharding annotation ... is only valid for values of rank 2, but was applied to a value of rank 3. I can change the create_device_mesh setup call to

devices = mesh_utils.create_device_mesh((num_devices, 1, 1))

and then it's OK with my rank-3 tensor but not the rank-2 one!

OK, I just figured this out while writing this up. You use sharding.replicate() (as in train_step) to make a sharding that can handle the additional dimensions. The JAX documentation for this is very limited! It may be worth adding a line to the example explaining this.
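For comparison, here is a minimal sketch of a rank-agnostic setup. It assumes a JAX version that provides jax.sharding.NamedSharding and swaps that in (together with a PartitionSpec) for the PositionalSharding used in the thread and the example. The mesh axis name "batch" and all array shapes are illustrative. A spec that names only the leading axis shards the batch dimension and leaves any trailing axes whole, so one sharding fits both a rank-3 x and a rank-2 y. An empty spec gives a fully replicated sharding, playing the role of sharding.replicate().

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

num_devices = len(jax.devices())
# A 1-D mesh with one named axis ("batch", an illustrative name) spanning every device.
mesh = Mesh(mesh_utils.create_device_mesh((num_devices,)), axis_names=("batch",))

# Shard only the leading axis; unnamed trailing axes stay unsharded,
# so this one sharding is valid for arrays of any rank >= 1.
batch_sharding = NamedSharding(mesh, P("batch"))
# Empty spec: every device holds a full copy (the analogue of replicate()).
replicated = NamedSharding(mesh, P())

batch = 4 * num_devices  # the leading axis must split evenly across devices
x = jax.device_put(jnp.ones((batch, 3, 5)), batch_sharding)  # rank-3 data
y = jax.device_put(jnp.ones((batch, 7)), batch_sharding)     # rank-2 data
w = jax.device_put(jnp.ones((5, 7)), replicated)             # e.g. a weight matrix
```

With this setup the mesh never needs an extra size-1 axis per tensor rank, which is what the (num_devices, 1, 1) workaround was effectively adding.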

Thanks for all the great work on Equinox! I'm on day 1.5 of trying out JAX, coming from PyTorch, and it's been very helpful.

davidaknowles commented 3 months ago

Follow-up question: in the example, the data (and, for that matter, the model) are sharded both in the training loop and in train_step. Is that correct? My code seems to run fine when I shard the data only in the training loop and not in train_step.

patrick-kidger commented 3 months ago

Glad you figured it out!

On sharding in each location: yup, that's totally fine. In the example, the sharding done outside train_step is what actually shards things, and the one inside train_step corresponds to an assertion that things are sharded in the expected manner.
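To make the two roles concrete, here is a plain-JAX sketch (it does not use Equinox; the toy linear model, loss_fn and the "batch" axis name are hypothetical). To my understanding, eqx.filter_shard behaves like jax.device_put when called outside JIT and like jax.lax.with_sharding_constraint inside it. In the sketch, device_put outside the jitted step is what places the data. The constraints inside train_step are effectively no-ops because the inputs already arrive with those shardings. Strictly speaking, with_sharding_constraint is a constraint that XLA will satisfy (resharding if needed) rather than a check that raises, so dropping it only changes behaviour if inputs arrive sharded differently.

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(mesh_utils.create_device_mesh((len(jax.devices()),)), axis_names=("batch",))
data_sharding = NamedSharding(mesh, P("batch"))  # split the batch axis
replicated = NamedSharding(mesh, P())            # full copy on every device

def loss_fn(w, x, y):
    # Hypothetical toy model: a linear map with mean squared error.
    return jnp.mean((x @ w - y) ** 2)

@jax.jit
def train_step(w, x, y):
    # Inside jit: constrain the layout. Since the inputs already have
    # these shardings, no data movement is needed here.
    w = jax.lax.with_sharding_constraint(w, replicated)
    x = jax.lax.with_sharding_constraint(x, data_sharding)
    y = jax.lax.with_sharding_constraint(y, data_sharding)
    loss, grad = jax.value_and_grad(loss_fn)(w, x, y)
    return w - 0.1 * grad, loss

n = 4 * len(jax.devices())
# Outside jit (as in the training loop): this is what places the arrays.
w = jax.device_put(jnp.ones((3, 2)), replicated)
x = jax.device_put(jnp.ones((n, 3)), data_sharding)
y = jax.device_put(jnp.zeros((n, 2)), data_sharding)
new_w, loss = train_step(w, x, y)
```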

davidaknowles commented 2 months ago

Thanks!