pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

SPMD - how to use different dataloader on each VM of a TPU pod in SPMD #7850

Open dudulightricks opened 3 months ago

dudulightricks commented 3 months ago

❓ Questions and Help

In SPMD mode, when we run the training command on all the VMs together (single program, multiple machines), each VM has its own dataloader using its CPU cores. However, when we call mark_sharding on the batch, it effectively copies the batch of the first VM (rank 0) to all the TPUs and ignores the batches of the other VMs (which were loaded by different dataloaders).

To work around this (i.e., to have the dataloaders on all the VMs load different data and use all of it), we added a torch.distributed.all_gather_object call on the batch object to build one huge batch before calling mark_sharding. The problem is that in this case we are worried the huge batch is held in the memory of one VM before the sharding.

The ideal solution for us would be something like batch.mark_sharding(gather_all=True), which, instead of ignoring the batches on the other VMs, would gather them all together logically and apply mark_sharding to the resulting huge batch (which is then physically split across the TPUs). That way we would use all of the loaded data without blowing up the memory of the first VM. Is there any command like that? How can we use the data loaded by the dataloaders on all the different VMs? This matters in our case because the data is large and takes a long time to load. @JackCaoG
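
For reference, our current workaround looks roughly like this (a simplified sketch; `local_loader`, the mesh layout, and the tensor shapes are illustrative, and we assume `torch.distributed` is already initialized across the hosts):

```python
import numpy as np
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh

xr.use_spmd()

# 1D data-parallel mesh over every TPU device in the pod.
num_devices = xr.global_runtime_device_count()
mesh = Mesh(np.arange(num_devices), (num_devices,), ('data',))

# Each VM loads its own CPU batch from its local dataloader ...
local_batch = next(iter(local_loader))

# ... then gathers the batches of all the other VMs as Python objects, so the
# full global batch is materialized in host memory before sharding
# (this is the part we are worried about).
gathered = [None] * dist.get_world_size()
dist.all_gather_object(gathered, local_batch)
global_batch = torch.cat(gathered, dim=0).to(xm.xla_device())

# Shard only the batch dimension across the TPU devices.
xs.mark_sharding(global_batch, mesh, ('data', None))
```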

JackCaoG commented 3 months ago

@alanwaketan can you take this one?

alanwaketan commented 2 months ago

So, you want one VM to load all the data and distribute it to the other devices?

dudulightricks commented 2 months ago

@alanwaketan Exactly the opposite. We want to load the data using all the VMs (a dataloader on each one) while still using SPMD.

alanwaketan commented 2 months ago

You can just use the data loader following: https://github.com/pytorch/xla/blob/master/docs/spmd_advanced.md#sharding-aware-host-to-device-data-loading

Do you have any concerns on following the tutorial?
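
Roughly, the pattern from that doc looks like this (a sketch only, not tested here; `train_loader` is your regular per-host CPU DataLoader, and the partition spec assumes a 4D image batch):

```python
import numpy as np
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
import torch_xla.distributed.parallel_loader as pl
from torch_xla.distributed.spmd import Mesh

xr.use_spmd()

# One mesh axis ('data') spanning every TPU device in the pod.
num_devices = xr.global_runtime_device_count()
mesh = Mesh(np.arange(num_devices), (num_devices,), ('data',))

# Wrap the CPU DataLoader; input_sharding tells the loader how each yielded
# batch should be laid out on device: batch dim sharded along 'data',
# the remaining dims replicated.
train_device_loader = pl.MpDeviceLoader(
    train_loader,
    xm.xla_device(),
    input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
```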

dudulightricks commented 2 months ago

We have tried using input_sharding=xs.ShardingSpec, if that's what you mean, but it still just uses the batch from the dataloader of the first VM and ignores all the others. We need a way to use all the data loaded on all the VMs and shard it without holding it all on one VM. @alanwaketan

alanwaketan commented 2 months ago

> We have tried using input_sharding=xs.ShardingSpec, if that's what you mean, but it still just uses the batch from the dataloader of the first VM and ignores all the others. We need a way to use all the data loaded on all the VMs and shard it without holding it all on one VM. @alanwaketan

The default setting in SPMD is that all VMs load the same batch of data, and then the dataloader selects the corresponding shard and sends it to the device.

dudulightricks commented 2 months ago

So how do I change it? I want to use all the VMs to load data and then work on it sharded across all the TPUs.

JackCaoG commented 2 months ago

All you need is to pass the sharding to the loader, similar to https://github.com/pytorch/xla/blob/5f82da90e744b9c8da8690a0f4cc269f7fa474c9/examples/data_parallel/train_resnet_spmd_data_parallel.py#L40-L44

Each host will load all of the batches into host RAM, but only the portion of the data that belongs to this host will be transferred from RAM to device memory. I think the last slide of my video https://www.youtube.com/watch?v=cP5zncbNTEk explains this.
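
With input_sharding set on the MpDeviceLoader, the training loop itself does not change; something like the sketch below (`model`, `optimizer`, `loss_fn`, and `train_device_loader` are placeholders, the last one built as in the snippet earlier in the thread):

```python
for data, target in train_device_loader:
    # Every host's CPU DataLoader loads the same full batch into host RAM, but
    # only the shard of `data` that belongs to this host is moved to device memory.
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    optimizer.step()
    # MpDeviceLoader handles the step marking (xm.mark_step) between batches.
```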