pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

How to use SPMD to support hybrid sharded data parallelism? #7607

Open mars1248 opened 3 months ago

mars1248 commented 3 months ago

❓ Questions and Help

FSDP can be expressed well with SPMD, but HSDP does not seem to be expressible. Is there any way to express HSDP in SPMD?

JackCaoG commented 3 months ago

My understanding is that HSDP is just FSDP plus sharding the data? Let me know what I missed here.

mars1248 commented 3 months ago

You have a point, but how would you implement FSDP plus data sharding? How would you represent this with SPMD's mark_sharding? Or is there some other way to represent it?

JackCaoG commented 3 months ago

Take a look at https://github.com/pytorch/xla/blob/master/examples/fsdp/train_decoder_only_fsdp_v2.py. You just need to pass the sharding to the parallel loader, which will shard the data for you.
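
For illustration, here is a minimal sketch (not taken from the linked example) of one way to express the "FSDP + sharded data" layout directly with `mark_sharding`, assuming a 2D device mesh with hypothetical axis names `'replica'` and `'fsdp'`, and a toy `nn.Linear` standing in for the real model:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

device = xm.xla_device()
model = torch.nn.Linear(1024, 1024).to(device)   # toy stand-in for a real model
batch = torch.randn(16, 1024).to(device)

num_devices = xr.global_runtime_device_count()
# Hypothetical geometry: 2 replica groups, each running FSDP over
# num_devices // 2 devices (assumes an even device count).
mesh = xs.Mesh(np.arange(num_devices), (2, num_devices // 2), ('replica', 'fsdp'))

# Parameters: shard dim 0 along the 'fsdp' axis only, so each replica group
# holds its own full (sharded) copy of the weights -- the "hybrid" part of HSDP.
for param in model.parameters():
    if param.dim() > 0:
        xs.mark_sharding(param, mesh, ('fsdp',) + (None,) * (param.dim() - 1))

# Inputs: shard the batch dimension across both mesh axes, i.e. ordinary data
# parallelism over all devices.
xs.mark_sharding(batch, mesh, (('replica', 'fsdp'), None))
```

The key point is that parameters are sharded only along the `'fsdp'` axis (and therefore replicated across the `'replica'` groups), while the batch dimension is sharded across both axes.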
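
And a sketch of the parallel-loader route the comment above points to, using `MpDeviceLoader`'s `input_sharding` argument with the same hypothetical mesh and a hybrid partition spec for the inputs; `loader`, `device`, and `mesh` are assumed to come from the previous snippet, and `loader` is an ordinary torch `DataLoader`:

```python
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.spmd as xs

train_device_loader = pl.MpDeviceLoader(
    loader,
    device,
    # Shard the batch dimension of each input across both mesh axes, so the
    # global batch is split over every device while the model parameters stay
    # sharded only along 'fsdp'.
    input_sharding=xs.ShardingSpec(mesh, (('replica', 'fsdp'), None)),
)
```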