huggingface / datasets

🤗 The largest hub of ready-to-use datasets for ML models with fast, easy-to-use and efficient data manipulation tools
https://huggingface.co/docs/datasets
Apache License 2.0

Filter on dataset too much slowww #1796

Open ayubSubhaniya opened 3 years ago

ayubSubhaniya commented 3 years ago

I have a dataset with 50M rows. For pre-processing, I need to tokenize it and filter out rows with long sequences.

My tokenization took roughly 12 minutes; I used map() with batch size 1024 and multiprocessing with 96 processes.

When I apply the filter() function, it takes too much time. I need to filter sequences based on a boolean column. Below are the variants I tried.

  1. filter() with batch size 1024, single process (takes roughly 3 hr)
  2. filter() with batch size 1024, 96 processes (takes 5-6 hrs ¯\_(ツ)_/¯)
  3. filter() with all data loaded in memory, using only a single boolean column (never ends).

Can someone please help?

Below is sample code on a small dataset.

import random

from datasets import load_dataset

dataset = load_dataset('glue', 'mrpc', split='train')

# add a random boolean column to filter on
dataset = dataset.map(lambda x: {'flag': random.randint(0, 1) == 1})

def _amplify(data):
    # with input_columns=['flag'] and the default batched=False, `data` is the flag value itself
    return data

dataset = dataset.filter(_amplify, batch_size=1024, keep_in_memory=False, input_columns=['flag'])

ayubSubhaniya commented 3 years ago

When I use the filter on the arrow table directly, it works like butter. But I can't find a way to update the table in the Dataset object.

ds_table = dataset.data.filter(mask=dataset['flag'])
ayubSubhaniya commented 3 years ago

@thomwolf @lhoestq can you guys please take a look and recommend a solution?

lhoestq commented 3 years ago

Hi ! Currently the filter method reads the dataset batch by batch and writes a new, filtered arrow file on disk, so all that reading + writing can take some time. Using a mask directly on the arrow table doesn't do any read or write operation, so it's way quicker.

Replacing the old table by the new one should do the job:

dataset._data = dataset._data.filter(...)

Note: this is a workaround, and in general users shouldn't have to do that. In particular, if you did some shuffle or select before that, it would not work correctly since the indices mapping (index from __getitem__ -> index in the table) would no longer be valid. But if you haven't done any shuffle, select, shard, train_test_split etc. then it should work.
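
For illustration, here is a minimal sketch of that workaround. Hedged: _data and _indices are internal attributes, so this relies on implementation details, and it assumes a plain dataset with a boolean 'flag' column and no prior shuffle/select/shard:

# the raw-table swap is only safe when no indices mapping exists
assert dataset._indices is None, "indices mapping present; raw-table filtering would be incorrect"

# filter the underlying Arrow table with a boolean mask (no batch-by-batch read/write on disk)
dataset._data = dataset._data.filter(mask=dataset['flag'])

print(len(dataset))  # reflects the filtered table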

Ideally it would be awesome to update the filter function to allow masking this way ! If you would like to give it a shot I will be happy to help :)

ayubSubhaniya commented 3 years ago

Yes, would be happy to contribute. Thanks

gchhablani commented 3 years ago

Hi @lhoestq @ayubSubhaniya,

If there's no progress on this one, can I try working on it?

Thanks, Gunjan

lhoestq commented 3 years ago

Sure @gchhablani, feel free to start working on it, this would be very appreciated :) This feature would be really awesome, especially since Arrow allows masking really quickly and without having to rewrite the dataset on disk.

susnato commented 9 months ago

Hi @lhoestq, any updates on this issue? The filter method is still veryyy slow 😕

lhoestq commented 9 months ago

No update so far, we haven't worked on this yet :/

Though PyArrow is much more stable than 3 years ago so it would be a good time to dive into this

susnato commented 9 months ago

Hi @lhoestq, thanks a lot for the update!

I would like to work on this (if possible). Could you please give me some pointers on how I should approach this? Any references would also be great!

lhoestq commented 9 months ago

I just played a bit with it to make sure using table.filter() is fine, but it actually seems to create a new table in memory :/ This is an issue since it can quickly fill up RAM, and datasets' role is to make sure you can load bigger-than-memory datasets. Therefore I don't think it's a good idea in the end to use table.filter()
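
For what it's worth, here is a rough sketch one could use to observe that allocation, using pyarrow's built-in allocation counter (the table and numbers are illustrative and vary by machine):

import pyarrow as pa

# build a small in-memory table with a boolean column
tbl = pa.table({"x": list(range(1_000_000)), "flag": [i % 2 == 0 for i in range(1_000_000)]})

before = pa.total_allocated_bytes()
selected = tbl.filter(tbl["flag"])  # materializes the selected rows as a new in-memory table
after = pa.total_allocated_bytes()

print(f"extra allocation: {(after - before) / 1e6:.1f} MB")  # grows with the filtered size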

Anyway, I just ran OP's code and it runs in 20ms now on my side thanks to the I/O optimizations we did.

Another way to speed up filter would be to add support for pyarrow expressions, using e.g. arrow formatting + dataset.filter (runs in 10ms on my side):

import pyarrow.dataset as pds
import pyarrow.compute as pc

# expression evaluated natively by Arrow, no Python-level row loop
expr = pc.field("flag") == True

filtered = dataset.with_format("arrow").filter(
    # each batch arrives as a pyarrow Table; project the expression into a boolean mask
    lambda t: pds.dataset(t).to_table(columns={"mask": expr})[0].to_numpy(),
    batched=True,
).with_format(None)
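
For context on how this snippet works: with_format("arrow") makes filter hand each batch to the callable as a pyarrow Table t, wrapping it in an in-memory pds.dataset lets Arrow evaluate expr vectorized over the whole batch, and the projected "mask" column (taken with [0]) is returned as the boolean mask for that batch. Built-in expression support in filter would presumably let datasets skip the Python callable entirely.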