[Closed] davidaknowles closed 2 months ago
Follow-up question: in the example, the data (and, for that matter, the model) are sharded both in the training loop and in `train_step`. Is that correct? My code seems to run fine with just doing the sharding of the data in the training loop but not in `train_step`.
Glad you figured it out!
On sharding in each location: yup, that's totally fine. In the example, the sharding outside the training loop is what's actually doing the sharding, and the one inside `train_step` corresponds to an assertion that things are sharded in the expected manner.
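In plain JAX the pattern looks roughly like this (a minimal sketch rather than the example verbatim; the `PositionalSharding` setup and the `train_step` body here are just stand-ins):

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding

num_devices = jax.device_count()
# A rank-2 sharding: split the first (batch) axis across devices.
sharding = PositionalSharding(mesh_utils.create_device_mesh((num_devices, 1)))

@jax.jit
def train_step(x):
    # Inside jit this is a constraint/annotation: when x already arrives
    # sharded this way it acts as an assertion (and a no-op), rather than
    # being the thing that moves the data.
    x = jax.lax.with_sharding_constraint(x, sharding)
    return (x**2).mean()

x = jnp.ones((8 * num_devices, 4))
# This device_put in the training loop is what actually shards the batch.
x = jax.device_put(x, sharding)
loss = train_step(x)
```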
Thanks!
I'm trying to set up sharding on a TPU VM. I have data of different ranks, let's say `x` is rank-3 and `y` is rank-2. When I try to adapt the parallelism example I get an error saying `ValueError: One of with_sharding_constraint arguments is incompatible with its sharding annotation ... is only valid for values of rank 2, but was applied to a value of rank 3`. I can change the `create_device_mesh` setup call to … and then it's OK with my rank-3 tensor but not the rank-2 one!
OK, I just figured this out in the process of writing this: you use `sharding.replicate()` (as in `train_step`) to make a sharding that can handle additional dimensions. The (`jax`) documentation for this is very limited! Maybe worth adding a line to the example to explain this.
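For anyone else hitting this, here's a minimal sketch of what I mean (assuming the `PositionalSharding` setup from the parallelism example; the array shapes are made up):

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding

num_devices = jax.device_count()
devices = mesh_utils.create_device_mesh((num_devices, 1))
sharding = PositionalSharding(devices)  # rank-2: one sharding axis per array axis

x = jnp.ones((8 * num_devices, 3, 5))   # rank-3 data
y = jnp.ones((8 * num_devices, 5))      # rank-2 data

# The rank-2 sharding matches y, but applying it to x raises the
# "only valid for values of rank 2" error above.
y = jax.device_put(y, sharding)

# One fix: reshape the sharding so it has one axis per dimension of x.
x_sharded = jax.device_put(x, sharding.reshape(num_devices, 1, 1))

# The other: sharding.replicate() gives a fully replicated sharding that no
# longer cares about the value's rank (this is how train_step handles the
# model, whose leaves have all sorts of ranks).
x_replicated = jax.device_put(x, sharding.replicate())
```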
Thanks for all the great work on `equinox`, I'm on day 1.5 of trying out `jax` coming from `pytorch` and it's been very helpful.