huggingface / datasets

🤗 The largest hub of ready-to-use datasets for ML models with fast, easy-to-use and efficient data manipulation tools
https://huggingface.co/docs/datasets
Apache License 2.0

`sort` after `filter` unreasonably slow #7041

Open Tobin-rgb opened 1 month ago

Tobin-rgb commented 1 month ago

Describe the bug

As the title says, calling `sort` on a dataset after `filter` is unreasonably slow.

Steps to reproduce the bug

On its own, `sort` runs at a normal speed:

```python
from datasets import Dataset
import random

nums = [{"k": random.choice(range(0, 1000))} for _ in range(100000)]
ds = Dataset.from_list(nums)

print("start sort")
ds = ds.sort("k")
print("finish sort")
```

But `sort` after `filter` is extremely slow:

```python
from datasets import Dataset
import random

nums = [{"k": random.choice(range(0, 1000))} for _ in range(100000)]
ds = Dataset.from_list(nums)

# Keep only the rows whose "k" value is greater than 100
ds = ds.filter(lambda x: x > 100, input_columns="k")

print("start sort")
ds = ds.sort("k")
print("finish sort")
```

Expected behavior

Is this a bug, or is it a misuse of the sort function?

Environment info

lhoestq commented 1 month ago

`filter` adds an indices mapping on top of the dataset, so `sort` has to gather all the rows that are kept to form a new Arrow table and then sort that table. Gathering the rows can take some time, but it is a necessary step. You can try calling `ds = ds.flatten_indices()` before sorting to remove the indices mapping.