Open vyeevani opened 1 year ago
Very cool! Do you have a link to the code that you're using to eagerly fetch the data? Would also be interested in hacking around something here for pre-fetching iterable datasets
I ended up just switching back to the PyTorch DataLoader and using its multiprocessing functionality to handle this :(. I'm just not familiar enough with Python multiprocessing to get something working in Jupyter (I kept seeing weird behavior, with zombie processes living on after the cell finished).
Ultimately I settled on using webdataset to circumvent huggingface datasets entirely. Would definitely switch back if https://github.com/huggingface/datasets/issues/5337 were resolved.
Hi! You can combine datasets with torchdata to prefetch IterableDataset's samples:
from datasets import load_dataset
from torchdata.datapipes.iter import IterableWrapper
from torch.utils.data import DataLoader

ds = load_dataset("sst", split="train", streaming=True)
# processing...
dp = IterableWrapper(ds)          # wrap the streaming dataset as a DataPipe
dp = dp.prefetch(100)             # buffer up to 100 samples ahead of the consumer
dl = DataLoader(dp, batch_size=8)
it = iter(dl)
next(it)
Hey @mariosasko! Thanks for the tip here - introducing prefetch with torchdata didn't really give me any performance difference vs. not prefetching, but the concept is definitely one that could be really beneficial. Are there any benchmarks that show the speed-up you can get with torchdata's prefetch, just for comparison?
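For reference, a rough micro-benchmark could be put together along the lines of the sketch below (not from this thread; the dataset name, buffer size, and batch count are just placeholders). Prefetching mainly helps when the consumer does real work between batches (e.g. a training step), so timing back-to-back iteration alone may show little difference:

import time
from datasets import load_dataset
from torchdata.datapipes.iter import IterableWrapper
from torch.utils.data import DataLoader

def time_first_batches(dp, n_batches=50, batch_size=8):
    # Iterate over the first n_batches and return the elapsed time in seconds.
    it = iter(DataLoader(dp, batch_size=batch_size))
    start = time.perf_counter()
    for _ in range(n_batches):
        next(it)
    return time.perf_counter() - start

ds = load_dataset("sst", split="train", streaming=True)
print("no prefetch  :", time_first_batches(IterableWrapper(ds)))
print("prefetch(100):", time_first_batches(IterableWrapper(ds).prefetch(100)))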
Feature request
Add support for prefetching the next n batches through IterableDataset to reduce the batch-loading bottleneck in the training loop.
Motivation
The primary motivation behind this is to use hardware accelerators alongside a streaming dataset. This is required when you are in a low-RAM or low-disk-space setting, as well as for quick iteration where you're iterating through different accelerator environments (e.g. changing EC2 instances quickly to figure out batches/sec for a particular architecture).
Currently, using IterableDataset results in accelerators becoming basically useless due to the massive bottleneck induced by the dataset's lazy loading/transform/mapping.
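For illustration, here is a minimal, framework-agnostic sketch of the kind of background prefetching this request is asking for (the helper name prefetch_iter is hypothetical and not part of datasets): a background thread fills a bounded queue from the streaming dataset while the training loop consumes from it.

import queue
import threading

def prefetch_iter(iterable, buffer_size=16):
    # Yield items from `iterable`, keeping up to buffer_size items ready in a background thread.
    q = queue.Queue(maxsize=buffer_size)
    sentinel = object()

    def producer():
        for item in iterable:
            q.put(item)
        q.put(sentinel)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = q.get()
        if item is sentinel:
            return
        yield item

# Hypothetical usage with a streaming dataset and an accelerator training step:
# ds = load_dataset("sst", split="train", streaming=True)
# for sample in prefetch_iter(ds, buffer_size=32):
#     ...accelerator step...

Since this uses a thread rather than subprocesses, it side-steps the Jupyter zombie-process issues mentioned above, though it only overlaps I/O-bound loading with the training step.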
I've considered two alternatives:
1. The PyTorch DataLoader, which handles this (see the sketch after this list). However, I'm using JAX, and I believe this is a piece of functionality that should live in the streaming class.
2. Replicating the num_workers part of the PyTorch DataLoader to eagerly load batches and apply the transform, so that Arrow caching will automatically cache results and make them accessible.
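For completeness, a rough sketch of the first alternative (assuming a recent datasets release and the standard torch DataLoader; the transform is a placeholder), where the DataLoader's worker processes do the eager loading and transforming:

from datasets import load_dataset
from torch.utils.data import DataLoader

ds = load_dataset("sst", split="train", streaming=True)
ds = ds.map(lambda example: example)  # placeholder for the real transform

# num_workers > 0 moves iteration and the transform into worker subprocesses;
# how shards are split across workers depends on the datasets version.
dl = DataLoader(ds, batch_size=8, num_workers=2)
batch = next(iter(dl))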
Your contribution
I may or may not have time to do this. Currently, I've written a basic multiprocessing approach to handle the eager DataLoader for my own use case, with code that isn't integrated into datasets. I'd definitely see this becoming the default over the regular Dataset for most people, given that they wouldn't have to wait on the dataset downloads while also not worrying about performance.