dudulightricks opened 3 months ago
@alanwaketan can you take this one?
So, you want one VM to load all the data and distribute it to the other devices?
@alanwaketan Exactly the opposite. We want to load the data using all the VMs (dataloader on each one), but still using SPMD.
You can just use the data loader following: https://github.com/pytorch/xla/blob/master/docs/spmd_advanced.md#sharding-aware-host-to-device-data-loading
Do you have any concerns about following the tutorial?
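For reference, a minimal sketch of that tutorial's pattern; the 1-D data mesh, the 4-D partition spec, and the `train_dataset` / `global_batch_size` names are assumptions for an image-like input pipeline, not taken verbatim from the doc:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()

# One-dimensional 'data' mesh over every device in the SPMD world.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

# `train_dataset` / `global_batch_size` are placeholders for your own pipeline.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=global_batch_size)

# Shard the batch dimension across the 'data' axis; the loader then only
# transfers each device's shard of the input from host RAM to device memory.
train_device_loader = pl.MpDeviceLoader(
    train_loader,
    xm.xla_device(),
    input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
```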
We have tried using `input_sharding=xs.ShardingSpec`, if that's what you mean, but it still just uses the batch from the dataloader of the first VM and ignores all the others. We need a way to use all the data loaded on all the VMs and shard it without holding it all in one VM. @alanwaketan
The default setting in SPMD is that all VMs will load the same batch of data, and then the dataloader will select the corresponding shard and send it to the device.
So how do I change it? I want to use all the VMs to load the data and then work on it sharded across all the TPUs.
All you need is to pass the sharding to the loader, similar to https://github.com/pytorch/xla/blob/5f82da90e744b9c8da8690a0f4cc269f7fa474c9/examples/data_parallel/train_resnet_spmd_data_parallel.py#L40-L44
Each host will load all of the batches into host RAM, but only the portion of the data that belongs to this host will be loaded from RAM into device memory. I think the last slide of my video https://www.youtube.com/watch?v=cP5zncbNTEk explains this.
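To make the host-RAM vs. device-memory point concrete, here is a rough sketch of the consuming side, assuming `train_device_loader` was built with `input_sharding` as above; `model`, `optimizer`, and `loss_fn` are placeholders:

```python
import torch_xla.core.xla_model as xm

# Each host's DataLoader still materializes the full global batch in host RAM,
# but iterating the MpDeviceLoader only transfers this host's shard of each
# input to its devices; no extra mark_sharding on the batch is needed here.
for step, (data, target) in enumerate(train_device_loader):
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
    xm.mark_step()  # cut the graph and trigger execution for this step
```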
❓ Questions and Help
In SPMD mode, if we run the training command of a model on all the VMs together (single program, multiple machines), each VM has its own dataloader using its CPU cores. Then, when we use mark_sharding on the batch, it practically copies the batch of the first VM (rank 0) to all the TPUs and ignores the batches of the other VMs (which were loaded with different dataloaders).

To solve that (i.e., use all the dataloaders on the different VMs to load different data and use all of it), we have added torch.distributed.all_gather_object on the batch object to get one huge batch before calling mark_sharding. The problem is that in this case we are afraid the huge batch is held in the memory of one VM before the sharding.

The ideal solution for us would have been something like batch.mark_sharding(gather_all=True), which, instead of ignoring the different batches on the VMs, would gather them all together logically and apply mark_sharding to the resulting huge batch (which is in practice split over the TPUs). This way we would use all the loaded data without blowing up the memory of the first VM. Is there a command like that? How can we use the data loaded by all the dataloaders on the different VMs? In our case this matters because the data is large and takes time to load. @JackCaoG
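For clarity, a sketch of the workaround described above (not a recommendation); it assumes a torch.distributed process group is already initialized, that `mesh` is a 1-D SPMD data mesh with a 'data' axis, and that the batch is a 4-D tensor:

```python
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs

def gather_and_shard(local_batch, mesh):
    # Gather every host's batch object; note that all_gather_object gives
    # *every* host the full global batch in host memory -- the memory concern
    # raised above.
    world_size = dist.get_world_size()
    gathered = [None] * world_size
    dist.all_gather_object(gathered, local_batch.cpu())

    # Concatenate along the batch dimension, move to the XLA device, and
    # shard the batch dimension across the 'data' mesh axis.
    global_batch = torch.cat(gathered, dim=0).to(xm.xla_device())
    xs.mark_sharding(global_batch, mesh, ('data', None, None, None))
    return global_batch
```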