ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[Data] Add `stratify` argument in Dataset's train_test_split #34849

Open GokuMohandas opened 1 year ago

GokuMohandas commented 1 year ago

When dealing with imbalanced datasets, it's important to be able to stratify on a specific column's values (similar to scikit-learn's `train_test_split`). Right now, the split can only be done naively, which may lead to heavily skewed splits if the data has large class imbalances.

# Proposed usage: split the dataset, stratifying on the "tag" column
train_ds, val_ds = ds.train_test_split(test_size=0.3, stratify="tag")
GokuMohandas commented 1 year ago

Just FYI, as a temporary workaround I group the dataset by the column I want to stratify on (groupby), then use a UDF (map_groups) to split each group's batch DataFrame into train and test portions and add a text column recording which split each row belongs to. I then filter the full dataset on that column to produce the stratified train and test splits. In this implementation, stratify is a string, but I think it would be more robust to accept a general array-like input (v1 could accept just a string column name if that makes it achievable for the 2.5 release!).
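The group, label, and filter steps described above can be sketched in plain Python. This is only an illustration of the logic, not the workaround's actual code and not Ray's API: the function names `assign_splits` and `stratified_split`, the `"split"` helper column, and the `seed` parameter are all hypothetical, and rows are modeled as a list of dicts rather than a Ray Dataset.

```python
import random
from collections import defaultdict


def assign_splits(rows, stratify, test_size=0.3, seed=42):
    """Return copies of the rows with an added "split" field ("train" or "test")."""
    # Step 1: group row indices by the value of the stratify column.
    groups = defaultdict(list)
    for i, row in enumerate(rows):
        groups[row[stratify]].append(i)

    # Step 2: within each group, shuffle with a seeded RNG and mark the
    # first `test_size` fraction of the group as test rows.
    rng = random.Random(seed)
    labels = ["train"] * len(rows)
    for indices in groups.values():
        rng.shuffle(indices)
        n_test = round(len(indices) * test_size)
        for i in indices[:n_test]:
            labels[i] = "test"

    return [{**row, "split": label} for row, label in zip(rows, labels)]


def stratified_split(rows, stratify, test_size=0.3, seed=42):
    """Return (train_rows, test_rows), each preserving per-group proportions."""
    tagged = assign_splits(rows, stratify, test_size, seed)

    # Step 3: filter on the helper column, then drop it from the output.
    def select(name):
        return [
            {k: v for k, v in row.items() if k != "split"}
            for row in tagged
            if row["split"] == name
        ]

    return select("train"), select("test")


# Example: an imbalanced dataset with 10 "a" rows and 90 "b" rows.
rows = [{"id": i, "tag": "a" if i < 10 else "b"} for i in range(100)]
train_rows, test_rows = stratified_split(rows, stratify="tag", test_size=0.3)
```

Note that the result is reproducible only if the rows arrive in the same order on every run, because the seeded shuffle operates on positions within each group.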

Can we also ensure that reproducibility is possible here by exposing some configuration (even at the cost of some performance)? Setting ray.data.DatasetContext.get_current().execution_options.preserve_order = True doesn't make my approach above deterministic, so I assume something additional is needed.
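One way to get reproducibility without relying on row order at all is to derive each row's split from a hash of a stable row identifier. The sketch below is an assumption-laden illustration rather than anything Ray provides: `split_for_key` and its `salt` parameter are hypothetical names, and it assumes every row carries a unique, stable key (for example an `id` column). Applied per row, each class lands in the test split at roughly the `test_size` rate, so the stratification is approximate rather than exact.

```python
import hashlib


def split_for_key(key, test_size=0.3, salt="split-v1"):
    """Deterministically map a stable row key to "train" or "test"."""
    digest = hashlib.sha256(f"{salt}:{key}".encode("utf-8")).digest()
    # Turn the first 8 bytes of the digest into a roughly uniform value in [0, 1].
    u = int.from_bytes(digest[:8], "big") / 2**64
    return "test" if u < test_size else "train"


# The assignment depends only on the key, not on the order rows are processed in.
keys = list(range(10_000))
forward = {k: split_for_key(k) for k in keys}
backward = {k: split_for_key(k) for k in reversed(keys)}
test_fraction = sum(1 for v in forward.values() if v == "test") / len(keys)
```

Changing `salt` yields a different but equally reproducible split, which plays the role a random seed would otherwise play.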

stale[bot] commented 1 year ago

Hi, I'm a bot from the Ray team :)

To help human contributors focus on more relevant issues, I will automatically add the stale label to issues that have had no activity for more than 4 months.

If there is no further activity in the next 14 days, the issue will be closed!

You can always ask for help on our discussion forum or Ray's public Slack channel.