huggingface / datasets

🤗 The largest hub of ready-to-use datasets for ML models with fast, easy-to-use and efficient data manipulation tools
https://huggingface.co/docs/datasets
Apache License 2.0

Batched IterableDataset #6279

Open lneukom opened 9 months ago

lneukom commented 9 months ago

Feature request

Hi,

could you add an implementation of a batched IterableDataset? Batch iteration is already supported via .iter(batch_size=...), but this cannot be used in combination with a torch DataLoader since it just returns a plain iterator.
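For illustration, a minimal sketch of the current behaviour (the dataset and column names below are made up, not from this issue): .iter() already yields whole batches, but it returns a plain generator rather than something a DataLoader can wrap.

from datasets import Dataset

ds = Dataset.from_dict({"text": [f"example {i}" for i in range(1_000)]})

# .iter() already yields whole batches as dicts of lists...
for batch in ds.iter(batch_size=128):
    pass  # batch == {"text": [...128 strings...]}

# ...but it is just a generator, so it cannot be passed to
# torch.utils.data.DataLoader the way a Dataset or IterableDataset can.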

Motivation

The current implementation loads each element of a batch individually, which can be very slow when batch_size is large. I did some experiments here, and using batched iteration would speed up data loading significantly.
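To make the per-element vs. batched access difference concrete, here is a rough, self-contained illustration (names and sizes are made up; actual speedups depend on the data):

from datasets import Dataset

ds = Dataset.from_dict({"x": list(range(10_000))})

# What a map-style DataLoader effectively does: one __getitem__ call per example.
rows = [ds[i] for i in range(128)]

# What batched access does instead: one call returning the whole batch.
batch = ds[0:128]  # {"x": [0, 1, ..., 127]}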

Your contribution

N/A

VascoSch92 commented 9 months ago

This is exactly what I was looking for. It would also be very useful for me :-)

zhh210 commented 1 month ago

This issue really undermines the selling point of HF datasets... The only workaround I've found so far is to wrap the dataset in a customized IterableDataset, which improves the loading speed to some extent.

For example, I have an HF dataset dt_train with len(dt_train) around 1M. Using a plain DataLoader is extremely slow:

%%time
from torch.utils.data import DataLoader

dl_train = DataLoader(dt_train, batch_size=128, shuffle=True)
for batch in dl_train:
    pass
CPU times: user 24min 35s, sys: 704 ms, total: 24min 36s
Wall time: 24min 37s

And DataLoader works even worse with HF's iterable_dataset:

%%time
dt_train_ = dt_train.with_format(None).to_iterable_dataset(num_shards=64).shuffle(buffer_size=10_000)
dl_train = DataLoader(dt_train_, batch_size=128)
for batch in dl_train:
    pass
CPU times: user 1h 6min 2s, sys: 4.28 s, total: 1h 6min 6s
Wall time: 1h 7min 53s

Workaround by running a customized wrapper:

%%time
from torch.utils.data import DataLoader, IterableDataset

class Dataset2Iterable(IterableDataset):
    """
    Wrapper exposing an HF dataset as a PyTorch IterableDataset to speed up data loading.
    """
    def __init__(self, dataset, batch_size=1, shuffle=True):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        if self.shuffle:
            # .shuffle() returns a new dataset rather than shuffling in place
            self.dataset = self.dataset.shuffle()
        return self.dataset.iter(batch_size=self.batch_size)

dl_train = DataLoader(Dataset2Iterable(dt_train, batch_size=128), batch_size=1, num_workers=0)
for n in range(2):
    for batch in dl_train:
        pass

The speed is still slower than TensorFlow's loader, but it improved a lot over the previous code:

CPU times: user 4min 18s, sys: 0 ns, total: 4min 18s
Wall time: 4min 20s

Note that the way I implemented Dataset2Iterable will only work with num_workers=0.
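A rough, untested sketch of how the wrapper could be extended to num_workers > 0 by giving each DataLoader worker its own shard (this assumes the underlying HF dataset supports .shard(); the class name is made up):

from torch.utils.data import IterableDataset, get_worker_info

class ShardedDataset2Iterable(IterableDataset):
    """Like Dataset2Iterable, but each DataLoader worker iterates over its own shard."""
    def __init__(self, dataset, batch_size=1, shuffle=True):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        dataset = self.dataset
        worker_info = get_worker_info()
        if worker_info is not None and worker_info.num_workers > 1:
            # Disjoint shards per worker so batches are not duplicated.
            dataset = dataset.shard(num_shards=worker_info.num_workers, index=worker_info.id)
        if self.shuffle:
            dataset = dataset.shuffle()
        return dataset.iter(batch_size=self.batch_size)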

jaketae commented 3 days ago

I can confirm that @zhh210's solution works with num_workers=0. However, for my use case, this was still slower than tokenizing on the fly through a collator and leveraging multiple workers in the dataloader.

@lhoestq I think this is an important use case (e.g., streaming from a large dataset, online or stored on disk). What do you think might be the best solution to move forward?

lhoestq commented 4 hours ago

I guess it can be implemented using a batched .map() under the hood that returns a single item containing the input batch.

In the meantime you can use this:

def batch(unbatched: dict[str, list]) -> dict[str, list]:
    return {k: [v] for k, v in unbatched.items()}

batched_dataset = dataset.map(batch, batched=True, batch_size=batch_size)
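As a hedged usage note for the snippet above: since every row of batched_dataset is now a full batch, it can be fed to a DataLoader with batch_size=None so PyTorch does not try to re-collate it (untested sketch):

from torch.utils.data import DataLoader

# batch_size=None disables PyTorch's automatic batching; each yielded item
# is already a dict of lists holding batch_size original examples.
loader = DataLoader(batched_dataset, batch_size=None)
for batch in loader:
    pass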

Though it would be great to have a .batch() method indeed. I'd be happy to help if anyone wants to open a PR.

lappemic commented 3 hours ago

If no one else is planning to work on this, I can take it on. I'll wait until next week, and if no one has started a PR by then, I'll go ahead and open one.