awslabs / s3-connector-for-pytorch

The Amazon S3 Connector for PyTorch delivers high throughput for PyTorch training jobs that access and store data in Amazon S3.
BSD 3-Clause "New" or "Revised" License

Error when importing laion-art dataset #102

Closed awsankur closed 8 months ago

awsankur commented 8 months ago

s3torchconnector version

1.1.0

s3torchconnectorclient version

1.1.0

AWS Region

us-west-2

Describe the running environment

Running in EC2 m5.8xlarge Amazon Linux 2

What happened?

I am following the example notebook here: https://github.com/awslabs/s3-connector-for-pytorch/blob/main/examples/Getting%20s[…]ed%20with%20the%20Amazon%20S3%20Connector%20for%20PyTorch.ipynb

I need to modify it to load a dataset whose images are sharded across hundreds of .tar files. An example is the laion-art dataset. I am using the following code:

import s3torchconnector

from PIL import Image
import webdataset
import itertools
import io
import torch
import torchvision

def shard_to_dict(object):
    return {"url": object.key, "stream": object}

IMAGES_URI = "s3://laion-art/laion-art-data/tarfiles_reorganized/task0000/"

REGION = "us-west-2"

s3_dataset = s3torchconnector.S3MapDataset.from_prefix(IMAGES_URI, region=REGION, transform=shard_to_dict)

dataset = webdataset.tariterators.tar_file_expander(s3_dataset)

loader = torch.utils.data.DataLoader(dataset, batch_size=4)
for batch in loader:
    print(batch)

But I get the error: TypeError: object of type 'generator' has no len()

Relevant log output

(pt-nightlies) ubuntu@ip-10-0-55-158:/apps$ python3 s3connector.py
Traceback (most recent call last):
  File "/apps/s3connector.py", line 23, in <module>
    for batch in loader:
  File "/apps/.conda/envs/pt-nightlies/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
    data = self._next_data()
  File "/apps/.conda/envs/pt-nightlies/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 674, in _next_data
    index = self._next_index()  # may raise StopIteration
  File "/apps/.conda/envs/pt-nightlies/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 621, in _next_index
    return next(self._sampler_iter)  # may raise StopIteration
  File "/apps/.conda/envs/pt-nightlies/lib/python3.10/site-packages/torch/utils/data/sampler.py", line 287, in __iter__
    for idx in self.sampler:
  File "/apps/.conda/envs/pt-nightlies/lib/python3.10/site-packages/torch/utils/data/sampler.py", line 111, in __iter__
    return iter(range(len(self.data_source)))
TypeError: object of type 'generator' has no len()


muddyfish commented 8 months ago

Hi @awsankur, thanks for your report!

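The issue you are running into is caused by the webdataset.tariterators.tar_file_expander function returning a generator rather than a Dataset. DataLoader treats anything that is not an IterableDataset as a map-style dataset, so its default sampler calls len() on the generator, which raises the TypeError in your traceback.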
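The failure can be reproduced without S3 or PyTorch. The sketch below is only an illustration, not library code: `fake_tar_file_expander` is a hypothetical stand-in for webdataset.tariterators.tar_file_expander, the shard dictionaries are made up, and `sequential_indices` mimics the sampler line shown in the traceback (`iter(range(len(self.data_source)))`).

```python
def fake_tar_file_expander(shards):
    """Hypothetical stand-in: lazily yields one sample per file in each shard."""
    for shard in shards:
        for name in shard["files"]:
            yield {"url": shard["url"], "fname": name}


def sequential_indices(data_source):
    """Mimics the sampler line from the traceback: it needs len(data_source)."""
    return iter(range(len(data_source)))


# Made-up shard listing, used only for illustration.
shards = [
    {"url": "shard-0.tar", "files": ["a.jpg", "b.jpg"]},
    {"url": "shard-1.tar", "files": ["c.jpg"]},
]

samples = fake_tar_file_expander(shards)  # a generator, not a Dataset

try:
    sequential_indices(samples)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)

# Plain iteration over the generator still works.
streamed = list(samples)
print(len(streamed))
```

Because the generator can still be iterated, treating it as an iterable-style dataset avoids the len() call altogether.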

The external torchdata package we use in the example provides an IterableWrapper which can be used to fix the issue:

import torchdata.datapipes.iter  # provides IterableWrapper; add to the imports above
...
s3_dataset = s3torchconnector.S3IterableDataset.from_prefix(IMAGES_URI, region=REGION, transform=shard_to_dict)
tar_dataset = webdataset.tariterators.tar_file_expander(s3_dataset)
dataset = torchdata.datapipes.iter.IterableWrapper(tar_dataset)

loader = torch.utils.data.DataLoader(dataset, batch_size=4)
...

Note that this starts from an S3IterableDataset rather than an S3MapDataset. Both will work, but once the result is converted with IterableWrapper you lose all the advantages S3MapDataset provides, so there is no benefit to using it here.
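To illustrate what is lost, here is a small pure-Python sketch. It assumes the advantages in question are the usual map-style ones (a known length and access by index); `ShardIndex`, `stream`, and the shard names are hypothetical and do not come from s3torchconnector.

```python
import random


class ShardIndex:
    """Hypothetical stand-in for a map-style dataset: has len() and indexing."""

    def __init__(self, keys):
        self._keys = list(keys)

    def __len__(self):
        return len(self._keys)

    def __getitem__(self, index):
        return self._keys[index]


def stream(shards):
    """Hypothetical stand-in for a streaming step: yields items in fixed order."""
    for i in range(len(shards)):
        yield shards[i]


shard_index = ShardIndex(f"shard-{i}.tar" for i in range(5))

# Map-style: any visiting order is possible because items are addressed by index.
order = list(range(len(shard_index)))
random.Random(0).shuffle(order)
shuffled = [shard_index[i] for i in order]

# After streaming, only sequential iteration remains.
streamed = stream(shard_index)
first = next(streamed)
supports_len = hasattr(streamed, "__len__")
supports_indexing = hasattr(streamed, "__getitem__")
print(first, supports_len, supports_indexing)
```

Once the shards pass through a streaming step, the length and index access of the original object are no longer visible downstream, which is why starting from the iterable variant costs nothing in this pipeline.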

Please let us know if this resolves your issue.

awsankur commented 8 months ago

Thanks. It works

muddyfish commented 8 months ago

We've updated the example file to include a more concrete demonstration of how to use the tar file expander with dataloaders.