lneukom opened this issue 9 months ago
This is exactly what I was looking for. It would also be very useful for me :-)
This issue really undermines the selling point of HF datasets... The only workaround I've found so far is to wrap the dataset in a customized IterableDataset, which improves the loading speed to some extent.
For example, I have a HF dataset dt_train with len(dt_train) == 1M. Using a plain DataLoader is extremely slow:
%%time
from torch.utils.data import DataLoader

dl_train = DataLoader(dt_train, batch_size=128, shuffle=True)
for batch in dl_train:
    pass
CPU times: user 24min 35s, sys: 704 ms, total: 24min 36s
Wall time: 24min 37s
And DataLoader works even worse with HF's iterable_dataset:
%%time
dt_train_ = dt_train.with_format(None).to_iterable_dataset(num_shards=64).shuffle(buffer_size=10_000)
dl_train = DataLoader(dt_train_, batch_size=128)
for batch in dl_train:
    pass
CPU times: user 1h 6min 2s, sys: 4.28 s, total: 1h 6min 6s
Wall time: 1h 7min 53s
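(A possible mitigation, sketched here without benchmarks: the 64 shards created by to_iterable_dataset can be split across DataLoader workers, so each worker streams its own subset of shards in parallel.)

dl_train = DataLoader(dt_train_, batch_size=128, num_workers=8)  # 8 workers is an arbitrary choice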
Workaround using a customized wrapper:
%%time
from torch.utils.data import DataLoader, IterableDataset

class Dataset2Iterable(IterableDataset):
    """
    Wrapper to use a HF dataset as a pytorch IterableDataset to speed up data loading.
    """
    def __init__(self, dataset, batch_size=1, shuffle=True):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        if self.shuffle:
            # shuffle() returns a new dataset instead of shuffling in place
            self.dataset = self.dataset.shuffle()
        return self.dataset.iter(batch_size=self.batch_size)

dl_train = DataLoader(Dataset2Iterable(dt_train, batch_size=128), batch_size=1, num_workers=0)
for n in range(2):
    for batch in dl_train:
        pass
The speed is still slower than TensorFlow's loader, but much improved over the previous code:
CPU times: user 4min 18s, sys: 0 ns, total: 4min 18s
Wall time: 4min 20s
Note that the way I implemented Dataset2Iterable will only work with num_workers=0.
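If multiple workers are needed, one way to extend the wrapper is to give each DataLoader worker a disjoint shard of the dataset via get_worker_info() (an untested sketch; Dataset2IterableMulti and the per-worker sharding are assumptions, not part of the original comment):

from torch.utils.data import IterableDataset, get_worker_info

class Dataset2IterableMulti(IterableDataset):
    """Hypothetical variant of Dataset2Iterable that shards the HF dataset across workers."""
    def __init__(self, dataset, batch_size=1):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size

    def __iter__(self):
        worker_info = get_worker_info()
        ds = self.dataset
        if worker_info is not None:
            # Each worker iterates over its own contiguous shard of the dataset.
            ds = ds.shard(num_shards=worker_info.num_workers, index=worker_info.id)
        return ds.iter(batch_size=self.batch_size)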
I can confirm that @zhh210's solution works with num_workers=0. However, for my use case, this was still slower than tokenizing on the fly through a collator and leveraging multiple workers in the dataloader (sketched below).
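For reference, a minimal sketch of that collator approach, assuming a raw text column named "text" and a transformers tokenizer (both are assumptions about the setup, not part of the original comment):

from torch.utils.data import DataLoader
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed model; any tokenizer works

def collate_fn(examples):
    # Tokenize the whole batch at once, inside the worker process.
    return tokenizer(
        [ex["text"] for ex in examples],  # "text" is a hypothetical column name
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

dl_train = DataLoader(dt_train, batch_size=128, shuffle=True, num_workers=4, collate_fn=collate_fn)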
@lhoestq I think this is an important use case (e.g., streaming from a large dataset, online or stored on disk). What do you think might be the best solution to move forward?
I guess it can be implemented using a batched .map() under the hood that returns a single item containing the input batch.
In the meantime you can use this:
def batch(unbatched: dict[str, list]) -> dict[str, list]:
    return {k: [v] for k, v in unbatched.items()}

batched_dataset = dataset.map(batch, batched=True, batch_size=batch_size)
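Each row of batched_dataset is then one full batch, so it can be fed to a torch DataLoader with batch_size=None to disable re-collation (a sketch built on the snippet above):

from torch.utils.data import DataLoader

# batch_size=None tells the DataLoader that each dataset item is already a batch.
dl = DataLoader(batched_dataset.with_format("torch"), batch_size=None)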
Though it would be great to have a .batch() method indeed; I'd be happy to help if anyone wants to open a PR (one possible shape is sketched below).
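A hypothetical sketch of what such a .batch() method could look like, built on the batched map() above (the method name and signature are assumptions, not an agreed API; drop_last_batch mirrors the existing map() parameter):

def batch(self, batch_size: int, drop_last_batch: bool = False):
    """Hypothetical Dataset.batch(): group every batch_size rows into a single row."""
    def _batch(unbatched: dict[str, list]) -> dict[str, list]:
        return {k: [v] for k, v in unbatched.items()}

    return self.map(_batch, batched=True, batch_size=batch_size, drop_last_batch=drop_last_batch)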
If no one else is planning to work on this, I can take it on. I'll wait until next week, and if no one has started a PR by then, I'll go ahead and open one.
Feature request
Hi,
could you add an implementation of a batched IterableDataset? It already supports an option to do batch iteration via .iter(batch_size=...), but this cannot be used in combination with a torch DataLoader, since it just returns an iterator.
Motivation
The current implementation loads each element of a batch individually, which can be very slow for large batch sizes. I did some experiments here and using batched iteration would speed up data loading significantly.
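For reference, the .iter(batch_size=...) behavior the request refers to (a sketch, assuming a loaded dataset named dataset): it yields plain dicts of columns, which is why it cannot be wrapped by a torch DataLoader directly.

for batch in dataset.iter(batch_size=128):
    pass  # batch is a dict of columns; batch["some_column"] (hypothetical name) is a list of 128 values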
Your contribution
N/A