huggingface / datasets

🤗 The largest hub of ready-to-use datasets for ML models with fast, easy-to-use and efficient data manipulation tools
https://huggingface.co/docs/datasets
Apache License 2.0

Super slow iteration with trivial custom transform #6833

Open xslittlegrass opened 4 months ago

xslittlegrass commented 4 months ago

Describe the bug

Iterating over a Dataset is ~10x slower when a trivial custom transform is applied:

import time
import numpy as np
from datasets import Dataset, Features, Array2D

a = np.zeros((800, 800))
a = np.stack([a] * 1000)
features = Features({"a": Array2D(shape=(800, 800), dtype="uint8")})

ds1 = Dataset.from_dict({"a": a}, features=features).with_format('numpy')

def transform(batch):
    # trivial identity transform
    return batch

ds2 = ds1.with_transform(transform)

%time sum(1 for _ in ds1)  # numpy-formatted dataset
%time sum(1 for _ in ds2)  # same data, with the trivial transform applied
CPU times: user 472 ms, sys: 319 ms, total: 791 ms
Wall time: 794 ms
CPU times: user 9.32 s, sys: 443 ms, total: 9.76 s
Wall time: 9.78 s

In my real code I'm using set_transform to apply some on-the-fly post-processing to the 2D arrays, but it significantly slows down the dataset even when the transform itself is trivial.
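A rough sketch of that pattern, for context (postprocess and the scaling step are placeholders, not the actual post-processing code):

def postprocess(batch):
    # hypothetical on-the-fly post-processing of the 2D arrays
    batch["a"] = [np.asarray(x, dtype="float32") / 255.0 for x in batch["a"]]
    return batch

ds1.set_transform(postprocess)  # exhibits the same ~10x slowdown as the identity transform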

Related issue: https://github.com/huggingface/datasets/issues/5841

Steps to reproduce the bug

Use the code in the description to reproduce.

Expected behavior

The trivial custom transform in the example should not slow down dataset iteration.

Environment info

rangehow commented 4 months ago

Similar issue with text processing:


tokenizer = AutoTokenizer.from_pretrained(model_dir[args.model])
train_dataset = datasets.load_from_disk(dataset_dir[args.dataset], keep_in_memory=True)['train']
train_dataset = train_dataset.map(
    partial(dname2func[args.dataset], tokenizer=tokenizer),
    batched=True,
    num_proc=50,
    remove_columns=train_dataset.features.keys(),
    desc='tokenize',
    keep_in_memory=True,
)

After this, train_dataset looks like:

Dataset({
    features: ['input_ids', 'labels'],
    num_rows: 51760
})

Here input_ids and labels are both List[int]. However, each iteration over the dataset takes 7.412479639053345 s ……?

for j in tqdm(range(len(train_dataset)), desc='first stage'):
    input_id, label = train_dataset['input_ids'][j], train_dataset['labels'][j]
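(Side note: part of the cost here likely comes from the access pattern itself, since train_dataset['input_ids'] materializes the entire column as a Python list before taking element j, on every iteration. A sketch of row-first indexing that avoids that:)

for j in tqdm(range(len(train_dataset)), desc='first stage'):
    row = train_dataset[j]  # fetch a single row instead of re-extracting whole columns
    input_id, label = row['input_ids'], row['labels']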
lhoestq commented 4 months ago

The transform currently replaces the numpy formatting.

So you're back to copying the data to long Python lists, which is super slow.

It would be cool for the transform to not remove the formatting in this case, but this requires a few changes in the lib.
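Until something like that lands, one possible workaround (just a sketch; postprocess is a placeholder, and it assumes the post-processing can live outside the dataset) is to keep the numpy formatting and apply the transform after indexing, e.g. in the training loop or in a DataLoader collate_fn:

def postprocess(example):
    # hypothetical post-processing, applied after the fast numpy-formatted read
    example["a"] = example["a"].astype("float32") / 255.0
    return example

# ds1 keeps its numpy formatting, so no copy to Python lists happens
for example in ds1:
    example = postprocess(example)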