Open ssharpe42 opened 8 months ago
Hi @ssharpe42, thanks for filing this feature request. Is your dataset saved locally (built on top of a map-style dataset), so that you can filter it on the fly? Or are you using the HuggingFace dataset in streaming mode, built on top of IterableDataset?
For now, you would have to create the MDS datasets with the filter logic baked in and then use dataset mixing (check out the Stream class) to mix them on the fly. For example, create one MDS dataset for sequence_len < 10, another for 10 <= sequence_len < 20, another for 20 <= sequence_len < 30, and so on. After that, instantiate one Stream per MDS dataset and pass the sequence of Stream objects to StreamingDataset.
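The bucketing-and-mixing approach above can be sketched as follows. The bucket edges mirror the example ranges; the directory paths are hypothetical placeholders, and the `Stream`/`StreamingDataset` usage assumes the `mosaicml-streaming` package:

```python
def length_bucket(seq_len, edges=(10, 20, 30)):
    """Map a sequence length to a bucket index:
    [0, 10) -> 0, [10, 20) -> 1, [20, 30) -> 2, >= 30 -> 3."""
    for i, edge in enumerate(edges):
        if seq_len < edge:
            return i
    return len(edges)


def build_mixed_dataset(bucket_dirs, batch_size=32):
    """Mix one pre-filtered MDS dataset per length bucket into a single dataset.

    Requires the mosaicml-streaming package (imported lazily so the helper
    above runs without it). bucket_dirs is a list of hypothetical local
    paths, one per pre-built MDS dataset, e.g.:
        ["/data/mds/len_lt_10", "/data/mds/len_10_to_20", ...]
    """
    from streaming import Stream, StreamingDataset

    streams = [Stream(local=d) for d in bucket_dirs]
    return StreamingDataset(streams=streams, batch_size=batch_size)
```

At training time, only the `StreamingDataset` built from the mixed streams is iterated; the per-bucket filtering cost is paid once, at dataset-creation time.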
Yes, we are currently using the HF dataset in streaming mode, since our datasets are ~200GB of parquet files on disk. We are trying to find a more efficient way to iterate through them across many GPU nodes; currently the bottleneck seems to be that the dataloader workers in each process have to iterate through the whole dataset, skipping the samples reserved for other ranks.
In general, we want the ability to arbitrarily drop (filter out) samples based on conditions as we iterate. I'm not sure whether all of these needs are feasible together, but I'm open to other ideas for solving it! Based on what you are saying, I guess the best approach for now is to precompute MDS-format datasets for our desired filters.
@ssharpe42 Sorry for my late response. For now, precomputing MDS-format datasets for your desired filters is definitely one way to go.
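A minimal sketch of that precompute step, assuming the `mosaicml-streaming` package. The single `sequence` column encoded as `pkl` is an assumed schema for this example; adapt the column spec to your actual sample fields:

```python
def write_filtered_mds(samples, out_dir, keep):
    """Write only the samples passing `keep` to a new MDS dataset at out_dir.

    Requires mosaicml-streaming (imported lazily so the predicate below
    runs without it). Returns the number of samples written.
    """
    from streaming import MDSWriter

    columns = {"sequence": "pkl"}  # assumed schema for this sketch
    kept = 0
    with MDSWriter(out=out_dir, columns=columns) as writer:
        for sample in samples:
            if keep(sample):
                writer.write(sample)
                kept += 1
    return kept


def keep_long(sample, min_len=10):
    """Predicate matching the example filter: sequence length > 10."""
    return len(sample["sequence"]) > min_len
```

Each filter scenario then gets its own `out_dir`, and the resulting directories can be mixed as Streams at load time.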
@knighton Do you have any suggestions here?
🚀 Feature Request
I would like to be able to filter my data based on a callable that returns true/false, and to limit it to a specific number of samples, as HuggingFace datasets allows (https://huggingface.co/docs/datasets/process#select-and-filter).
Motivation
I want to be able to experiment with different parts of my dataset without creating a whole new MDS dataset for each scenario. I am working with sequence data (not text) and would like to use a filter like this, for example to only keep sequences of length > 10.
[Optional] Implementation
For example, to keep only sequences of length > 10:

```python
new_dataset = dataset.filter(lambda x: len(x["sequence"]) > 10)
```

Or to test on just the first 1M samples:

```python
new_dataset = dataset.limit(1_000_000)
```
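The proposed `filter`/`limit` semantics can be sketched as lazy wrappers over any iterable of samples; this is plain Python (not part of the streaming library), with hypothetical toy samples mimicking the schema above:

```python
from itertools import islice


def filtered(samples, predicate):
    """Lazily yield only the samples for which predicate(sample) is True."""
    return (s for s in samples if predicate(s))


def limited(samples, n):
    """Lazily yield at most the first n samples."""
    return islice(samples, n)


# Hypothetical toy samples following the issue's schema.
data = [{"sequence": list(range(k))} for k in (5, 12, 30)]
kept = list(limited(filtered(data, lambda x: len(x["sequence"]) > 10), 1_000_000))
# kept holds only the length-12 and length-30 samples
```

Since both wrappers are lazy, a real implementation over a streaming dataset would still drop samples on the fly rather than materializing the filtered set.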
Would also welcome any suggestions on the best way to go about this in the meantime!