awslabs / s3-connector-for-pytorch

The Amazon S3 Connector for PyTorch delivers high throughput for PyTorch training jobs that access and store data in Amazon S3.
BSD 3-Clause "New" or "Revised" License

Support multi-process/multi-node sharding for `S3IterableDataset` #53

Open jamesbornholt opened 10 months ago

jamesbornholt commented 10 months ago

We currently don't have a built-in way to shard `S3IterableDataset`, so every worker process in a DataLoader sees the same stream of objects. We should have a way to do this.

In the meantime, something like this, using the `IterableWrapper` datapipe and its `sharding_filter`, will work as a workaround:

```python
from s3torchconnector import S3IterableDataset
from torch.utils.data import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper

# Lists every object under the prefix; each item is an S3Reader
dataset = S3IterableDataset.from_prefix("s3://doc-example-bucket/", region="us-west-2")
# Wrap the dataset as a datapipe as-is, without deep-copying it
dataset = IterableWrapper(dataset, deepcopy=False)
# Add a sharding filter so each DataLoader worker keeps only its own share of the items
dataset = dataset.sharding_filter()
loader = DataLoader(dataset, num_workers=2)
```
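To make the idea concrete, here is a minimal, dependency-free sketch of index-based sharding, the same round-robin idea a sharding filter applies: every process enumerates the full stream and keeps only the items whose index maps to its shard. The helper `shard_for_worker` is hypothetical and not part of s3torchconnector. Inside an `IterableDataset.__iter__`, `worker_id` and `num_workers` would come from `torch.utils.data.get_worker_info()` (which returns `None` in the main process), and `rank`/`world_size` from `torch.distributed.get_rank()` and `get_world_size()`; here they are plain arguments so the logic can run on its own.

```python
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def shard_for_worker(
    items: Iterable[T],
    worker_id: int,
    num_workers: int,
    rank: int = 0,
    world_size: int = 1,
) -> Iterator[T]:
    """Hypothetical helper: yield only the items assigned to one worker on one rank.

    Assumes every process enumerates `items` in the same order.
    """
    if not 0 <= worker_id < num_workers:
        raise ValueError("worker_id must be in [0, num_workers)")
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    # Flatten (rank, worker) into a single global shard index
    shard_id = rank * num_workers + worker_id
    num_shards = world_size * num_workers
    for index, item in enumerate(items):
        # Round-robin assignment: item i belongs to shard i % num_shards
        if index % num_shards == shard_id:
            yield item


if __name__ == "__main__":
    keys = [f"obj-{i}" for i in range(10)]
    # 2 nodes x 2 DataLoader workers = 4 disjoint shards
    for rank in range(2):
        for worker_id in range(2):
            shard = list(shard_for_worker(keys, worker_id, 2, rank=rank, world_size=2))
            print(rank, worker_id, shard)
```

The shards are disjoint and together cover every item, but each process still walks the full listing and discards what it doesn't own. That trades some redundant enumeration for not needing any coordination between workers.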
cfregly commented 6 months ago

Related pull request for Megatron: https://github.com/NVIDIA/Megatron-LM/pull/729

jamesbornholt commented 3 months ago

torchdata's IterableWrapper will be deprecated in a future release, but it will remain available in PyTorch core. I've updated the code example above to use the PyTorch core version instead.