mars1248 opened this issue 3 months ago (status: Open)
My understanding is that HSDP is just FSDP plus sharding the data. Let me know if I missed something here.
You have a point, but how would I implement FSDP plus data sharding? How would you represent this with SPMD's mark_sharding? Or is there another way to represent it?
Take a look at https://github.com/pytorch/xla/blob/master/examples/fsdp/train_decoder_only_fsdp_v2.py. You just need to pass the sharding to the parallel loader, which will shard the data for you.
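To make the layout under discussion concrete, here is a rough plain-Python sketch. It is not the torch_xla API and does not come from the linked example. It assumes the usual meaning of HSDP: devices form a 2D grid, parameters are sharded within each group and replicated across groups, and the batch is split across every device. The axis names `replica` and `fsdp`, the function `hsdp_layout`, and all sizes are hypothetical and only for illustration.

```python
from itertools import product


def hsdp_layout(num_replicas, num_shards, batch_size, param_rows):
    """Map each (replica, shard) device coordinate to the slices it would hold.

    Assumption: parameters are split only along the hypothetical `fsdp` axis,
    while the batch is split across both `replica` and `fsdp`.
    """
    num_devices = num_replicas * num_shards
    if batch_size % num_devices or param_rows % num_shards:
        raise ValueError("sizes must divide evenly in this sketch")

    batch_per_device = batch_size // num_devices
    rows_per_shard = param_rows // num_shards

    layout = {}
    for replica, shard in product(range(num_replicas), range(num_shards)):
        # Flatten the 2D coordinate so every device gets a distinct batch slice.
        flat = replica * num_shards + shard
        layout[(replica, shard)] = {
            "batch": (flat * batch_per_device, (flat + 1) * batch_per_device),
            # The parameter slice depends on `shard` only, so it repeats
            # (is replicated) across the `replica` axis.
            "param": (shard * rows_per_shard, (shard + 1) * rows_per_shard),
        }
    return layout


layout = hsdp_layout(num_replicas=2, num_shards=4, batch_size=16, param_rows=8)
for coord, slices in sorted(layout.items()):
    print(coord, slices)
```

In this sketch, devices `(0, s)` and `(1, s)` hold the same parameter rows but different batch rows. That is the property an SPMD sharding annotation for HSDP would need to capture: the parameter sharding uses one mesh axis, and the input sharding uses both.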
❓ Questions and Help
FSDP can be expressed well with SPMD, but HSDP does not seem to be expressible. Is there any way to express HSDP in SPMD?