QiyaoWei opened this issue 2 months ago.
Hi! You can do:

```python
filtered_dataset = dataset.filter(filter_function)
```

or, on a subset:

```python
filtered_subset = dataset.select(range(10_000)).filter(filter_function)
```
Jumping on this as it seems relevant: when I use the `filter` method, it often results in an OOM (or at least unacceptably high memory usage).
For example, in this notebook we load an object detection dataset from HF. Imagine I want to filter it so that I only keep images that contain a single annotation class. Each row has a JSON field with the MS-COCO annotations for the image, so we can parse that field and filter on it.
The test split is only about 440 images, probably less than 1 GB, but running the following filter crashes the VM (over 12 GB of RAM):
```python
import json


def filter_single_class(example, target_class_id):
    """Filters examples based on whether they contain annotations from a single class.

    Args:
        example: A dictionary representing a single example from the dataset.
        target_class_id: The target class ID to filter for.

    Returns:
        True if the example contains only annotations from the target class, False otherwise.
    """
    if not example['coco_annotations']:
        return False
    # Collect the distinct category IDs present in this image's annotations.
    annotation_category_ids = {annotation['category_id'] for annotation in json.loads(example['coco_annotations'])}
    return len(annotation_category_ids) == 1 and target_class_id in annotation_category_ids


target_class_id = 1
filtered_dataset = dataset['test'].filter(lambda example: filter_single_class(example, target_class_id))
```
Iterating over the dataset works fine:

```python
filtered_dataset = []
for example in dataset['test']:
    if filter_single_class(example, target_class_id):
        filtered_dataset.append(example)
```
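Because the predicate only parses a JSON string, its logic can be checked in plain Python before running it through `filter`. A minimal sketch: a self-contained copy of `filter_single_class` (same behaviour, with `category_ids == {target_class_id}` as shorthand for the original two checks) applied to hypothetical hand-written rows.

```python
import json


def filter_single_class(example, target_class_id):
    # Same logic as above: True only if every annotation has target_class_id.
    if not example["coco_annotations"]:
        return False
    category_ids = {a["category_id"] for a in json.loads(example["coco_annotations"])}
    return category_ids == {target_class_id}


# Hypothetical rows covering the edge cases.
rows = [
    {"coco_annotations": json.dumps([{"category_id": 1}, {"category_id": 1}])},  # only class 1
    {"coco_annotations": json.dumps([{"category_id": 1}, {"category_id": 2}])},  # two classes
    {"coco_annotations": json.dumps([{"category_id": 2}])},                      # wrong class
    {"coco_annotations": ""},                                                    # no annotations
]

kept = [i for i, row in enumerate(rows) if filter_single_class(row, target_class_id=1)]
print(kept)  # [0]
```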
It would be great if the documentation had guidance on how to use filters efficiently, or, if this is a performance bug, if it could be addressed. At the very least I would expect a filter operation to use at most 2x the footprint of the dataset plus some overhead for the lambda (i.e. the worst case would be a duplicate copy with all entries retained). Even if the operation is parallelised, each worker should only take a subset of the dataset, so I'm not sure where this ballooning in memory usage comes from.
From some other comments, there seems to be a workaround using `writer_batch_size` or caching to a file, but in the docs at least, `keep_in_memory` defaults to `False`.
You can try passing `input_columns=["coco_annotations"]` to only load this column instead of all the columns. In that case, your function should take `coco_annotations` as input instead of `example`.
If your `filter_function` is computationally expensive, consider using multi-processing or multi-threading with `concurrent.futures` to filter the dataset. This approach allows you to process multiple tables concurrently, reducing overall processing time, especially for CPU-bound tasks. Use `ThreadPoolExecutor` for I/O-bound operations and `ProcessPoolExecutor` for CPU-bound operations.
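A minimal sketch of that approach with `concurrent.futures`; `parallel_filter` and `is_good` are hypothetical names, and the "tables" are plain nested lists. It uses `ThreadPoolExecutor` so it runs anywhere; for a CPU-bound predicate you could pass `executor_cls=ProcessPoolExecutor`, provided the predicate is a picklable module-level function and the call sits under an `if __name__ == "__main__":` guard.

```python
from concurrent.futures import ThreadPoolExecutor


def is_good(table):
    # Hypothetical criterion: a "good" table has no missing values.
    return all(value is not None for row in table for value in row)


def parallel_filter(items, predicate, max_workers=4, executor_cls=ThreadPoolExecutor):
    """Return the items for which predicate is True, evaluating it concurrently.

    executor.map yields results in input order, so they line up with items.
    Note: this materializes items as a list; for data that does not fit in
    memory, call it on one chunk at a time.
    """
    items = list(items)
    with executor_cls(max_workers=max_workers) as executor:
        keep = list(executor.map(predicate, items))
    return [item for item, k in zip(items, keep) if k]


tables = [
    [[1, 2], [3, 4]],
    [[1, None], [3, 4]],
    [[5, 6]],
]
print(parallel_filter(tables, is_good))  # [[[1, 2], [3, 4]], [[5, 6]]]
```

For a `datasets.Dataset`, `filter` also accepts a `num_proc` argument for multiprocessing, which may be simpler than managing an executor by hand.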
Feature request
I am not sure if this is a new feature, but I wanted to post this problem here, and hear if others have ways of optimizing and speeding up this process.
Let's say I have a really large dataset that I cannot load into memory. At this point, I am only aware of `streaming=True` for loading the dataset. Now, the dataset consists of many tables. Ideally, I would want some simple filtering criterion so that I only see the "good" tables. Here is an example of what the code might look like:

And then I work on the processed dataset, which would be orders of magnitude faster than working on the original. I would love to hear if the problem setup + solution makes sense to people, and if anyone has suggestions!
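With `streaming=True`, `datasets` returns an `IterableDataset`, whose `filter` is applied lazily as you iterate. The underlying idea can be sketched in plain Python without the library: a generator yields one table at a time and the built-in `filter` tests each as it arrives, so only the current table needs to be in memory. `stream_tables`, `is_good`, and the `num_rows` field are hypothetical.

```python
from itertools import islice


def stream_tables(n):
    """Hypothetical stand-in for a streamed dataset: yields one table at a time."""
    for i in range(n):
        yield {"table_id": i, "num_rows": i % 5}


def is_good(table):
    # Hypothetical "good table" criterion.
    return table["num_rows"] >= 3


# Built-in filter() is lazy: nothing is read or tested until we iterate.
good_tables = filter(is_good, stream_tables(1_000_000))

# Take the first three good tables without touching the rest of the stream.
first_three = list(islice(good_tables, 3))
print([t["table_id"] for t in first_three])  # [3, 4, 8]
```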
Motivation
See description above
Your contribution
Happy to make PR if this is a new feature