Open radna0 opened 1 month ago
Take a look at the SPMD docs at https://pytorch.org/xla/release/r2.4/spmd.html.
If you want to do DDP with SPMD, try https://github.com/pytorch/xla/blob/master/examples/data_parallel/train_resnet_spmd_data_parallel.py (a minimal sketch of that pattern is below).
If you want to do FSDP with SPMD, try https://github.com/pytorch/xla/blob/master/examples/fsdp/train_decoder_only_fsdp_v2.py
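For reference, here is a minimal sketch of the data-parallel SPMD pattern the first example uses: build a 1-D mesh over all devices and shard only the batch dimension of the input. The batch shape and tensor here are placeholders, not taken from the linked script.

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh

xr.use_spmd()  # enable the SPMD execution mode before any XLA computation
num_devices = xr.global_runtime_device_count()
# 1-D mesh spanning every device, with a single 'data' axis.
mesh = Mesh(np.arange(num_devices), (num_devices,), ('data',))

batch = torch.randn(128, 3, 224, 224).to(xm.xla_device())
# Shard only the batch dimension across all devices: effectively DDP.
xs.mark_sharding(batch, mesh, ('data', None, None, None))
```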
❓ Questions and Help
I'm running this official script here, but I only see two XLA devices being used, xla:0 and xla:1. I've read that a TPU v2-8 or v3-8 instance has four chips (eight cores) per worker, so does that mean I can't fully use the memory of all the devices? Given this code here, there should be around 128 GB of total memory?
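A quick way to check what the runtime actually sees (this is a sketch, assuming a recent PyTorch/XLA build): under SPMD, all physical devices are visible to the runtime, but the program itself addresses a single logical device.

```python
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm

xr.use_spmd()  # must run before the first XLA computation
print(xr.global_runtime_device_count())  # e.g. 8 on a v3-8: all cores are visible
print(xm.xla_device())  # SPMD exposes one logical device, e.g. xla:0
```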
I want to use all of this memory capacity with SPMD, for example to shard and train a video diffusion transformer model, which requires sharding a 5-dimensional tensor (b, c, f, h, w). Since SPMD abstracts away the complexity of programming multiple devices and presents them as one big device, it seems perfect for the job?
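For illustration, a hypothetical sketch of sharding such a 5-D video tensor with SPMD (the mesh layout, tensor sizes, and the choice to split batch and height are assumptions, not a prescribed recipe):

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh

xr.use_spmd()
n = xr.global_runtime_device_count()  # e.g. 8 on a v3-8; assumed even below
# 2-D mesh: split the batch over 'data' and, say, the height over 'model'.
mesh = Mesh(np.arange(n), (n // 2, 2), ('data', 'model'))

video = torch.randn(8, 4, 16, 64, 64).to(xm.xla_device())  # (b, c, f, h, w)
# The partition spec has one entry per tensor dim; None means replicated.
xs.mark_sharding(video, mesh, ('data', None, None, 'model', None))
```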