huggingface / datasets

🤗 The largest hub of ready-to-use datasets for ML models with fast, easy-to-use and efficient data manipulation tools
https://huggingface.co/docs/datasets
Apache License 2.0

Prefetching for IterableDataset #5878

Open vyeevani opened 1 year ago

vyeevani commented 1 year ago

Feature request

Add support for prefetching the next n batches through IterableDataset to reduce the batch-loading bottleneck in the training loop.

Motivation

The primary motivation is to use hardware accelerators alongside a streaming dataset. This is required when you are in a low-RAM or low-disk-space setting, as well as for quick iteration where you're cycling through different accelerator environments (e.g. changing EC2 instances quickly to figure out batches/sec for a particular architecture).

Currently, using the IterableDataset results in accelerators becoming basically useless due to the massive bottleneck induced by the dataset lazy loading/transform/mapping.

I've considered two alternatives:

1. The PyTorch DataLoader, which handles this. However, I'm using JAX, and I believe this is a piece of functionality that should live in the streaming dataset class.

2. Replicating the "num_workers" part of the PyTorch DataLoader to eagerly load batches and apply the transform, so that Arrow caching will automatically cache results and make them accessible (a rough sketch of this follows below).
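
A minimal, hypothetical sketch of that second approach, using a single background thread rather than worker processes (the ThreadPrefetcher class and buffer_size argument are illustrative and not part of datasets):

import queue
import threading

class ThreadPrefetcher:
    """Eagerly pull examples from an iterable (e.g. a streaming
    IterableDataset) on a background thread so the training loop
    rarely has to wait on loading/transforming."""

    _SENTINEL = object()

    def __init__(self, iterable, buffer_size=100):
        self.iterable = iterable
        self.buffer_size = buffer_size

    def __iter__(self):
        buffer = queue.Queue(maxsize=self.buffer_size)

        def worker():
            for item in self.iterable:
                buffer.put(item)
            buffer.put(self._SENTINEL)  # signal end of stream

        threading.Thread(target=worker, daemon=True).start()
        while True:
            item = buffer.get()
            if item is self._SENTINEL:
                return
            yield item

# Hypothetical usage with a streaming dataset:
# ds = load_dataset("some_dataset", split="train", streaming=True)
# for example in ThreadPrefetcher(ds, buffer_size=64):
#     ...  # feed the accelerator

Because of the GIL, a thread like this mainly hides network/I/O latency; overlapping CPU-heavy transforms with training would need worker processes, which is what the num_workers comparison points at.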

Your contribution

I may or may not have time to do this. Currently, I've written a basic multiprocessing approach that handles the eager loading for my own use case, with code that isn't integrated into datasets. I'd definitely see this being the default over the regular Dataset for most people, given that they wouldn't have to wait on the dataset to download while also not having to worry about performance.

sanchit-gandhi commented 1 year ago

Very cool! Do you have a link to the code that you're using to eagerly fetch the data? I'd also be interested in hacking something together here for pre-fetching iterable datasets.

vyeevani commented 1 year ago

I ended up just switching back to the PyTorch DataLoader and using its multiprocessing functionality to handle this :(. I'm just not familiar enough with Python multiprocessing to get something working in Jupyter (I kept seeing weird behavior, with zombie processes living on after the cell finished).
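
For reference, the DataLoader-based setup being described is roughly the following (a sketch only; it assumes a datasets version where a streaming IterableDataset can be passed straight to a torch DataLoader, and the dataset name and worker counts are just placeholders):

from datasets import load_dataset
from torch.utils.data import DataLoader

ds = load_dataset("some_dataset", split="train", streaming=True)
ds = ds.with_format("torch")  # expose the stream as a torch-compatible IterableDataset

# Worker processes iterate the stream and prepare batches ahead of the
# training loop; prefetch_factor batches are kept ready per worker.
dl = DataLoader(ds, batch_size=8, num_workers=2, prefetch_factor=4)

for batch in dl:
    ...  # training step (or convert the batch to JAX arrays here)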

vyeevani commented 1 year ago

Ultimately I settled on using webdataset to circumvent huggingface datasets entirely. I would definitely switch back if https://github.com/huggingface/datasets/issues/5337 were resolved.

mariosasko commented 1 year ago

Hi! You can combine datasets with torchdata to prefetch IterableDataset's samples:

from datasets import load_dataset
from torchdata.datapipes.iter import IterableWrapper
from torch.utils.data import DataLoader

ds = load_dataset("sst", split="train", streaming=True)
# processing...

# Wrap the streaming dataset in a DataPipe and prefetch up to 100 samples
# ahead of the consumer, then batch with a regular DataLoader.
dp = IterableWrapper(ds)
dp = dp.prefetch(100)
dl = DataLoader(dp, batch_size=8)

i = iter(dl)
next(i)
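
A quick way to check whether the prefetch actually helps in a given setup is to time a fixed number of batches with and without it (a rough sketch, not an official benchmark; the gain depends mostly on how much time is spent waiting on network I/O versus Python-side processing):

import time

from datasets import load_dataset
from torch.utils.data import DataLoader
from torchdata.datapipes.iter import IterableWrapper

def time_n_batches(dl, n=100):
    """Return the wall-clock time needed to pull n batches from a DataLoader."""
    it = iter(dl)
    start = time.perf_counter()
    for _ in range(n):
        next(it)
    return time.perf_counter() - start

ds = load_dataset("sst", split="train", streaming=True)

plain = DataLoader(IterableWrapper(ds), batch_size=8)
prefetched = DataLoader(IterableWrapper(ds).prefetch(100), batch_size=8)

print("no prefetch:  ", time_n_batches(plain))
print("with prefetch:", time_n_batches(prefetched))
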
sanchit-gandhi commented 1 year ago

Hey @mariosasko! Thanks for the tip here - introducing prefetch with torchdata didn't really give me any performance difference vs not prefetching, but the concept is definitely one that could be really beneficial. Are there any benchmarks that show the speed-up you can get with torchdata's prefetch just for comparison?