dask / dask-ml

Scalable Machine Learning with Dask
http://ml.dask.org
BSD 3-Clause "New" or "Revised" License

Unexpected behavior in train_test_split with shuffle=False #992

Status: Open. Opened by divir94 6 months ago.

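When using train_test_split with shuffle=False on a Dask DataFrame, I notice two issues: 1) the rows are not split in their original order (the training set is not simply the first rows; for example, rows 8 and 9 end up in it), and 2) the train/test sizes are incorrect (6 and 4 rows instead of 5 and 5 with test_size=0.5). The behavior matches neither sklearn nor dask-ml itself when a pandas DataFrame is passed.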
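For reference, the behavior expected from shuffle=False can be sketched in plain pandas. This is a minimal sketch assuming the sizing rule sklearn's implementation uses for a float test_size, as we understand it (the test set gets ceil(test_size * n) rows and the training set gets the rest), and that without shuffling the training set is simply the first rows. ordered_split is a hypothetical helper for illustration, not a function in sklearn or dask-ml.

```python
import math

import numpy as np
import pandas as pd


def ordered_split(obj, test_size):
    """Hypothetical helper: split a pandas object without shuffling.

    Assumes sklearn's rule for a float test_size:
    n_test = ceil(test_size * n), n_train = n - n_test.
    Train is the head of the data, test is the tail.
    """
    n = len(obj)
    n_test = math.ceil(test_size * n)
    n_train = n - n_test
    # iloc slices by position, so the original row order is preserved
    return obj.iloc[:n_train], obj.iloc[n_train:]


df = pd.DataFrame(np.random.rand(10, 3), columns=["y", "x1", "x2"])
train, test = ordered_split(df["y"], test_size=0.5)
print(list(train.index), list(test.index))
# [0, 1, 2, 3, 4] [5, 6, 7, 8, 9]
```

On the 10-row example from this issue, this gives rows 0 to 4 for training and rows 5 to 9 for testing, matching the sklearn output.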

Minimal Complete Verifiable Example: Setup

import pandas as pd
import numpy as np
import dask.dataframe as dd

from sklearn.model_selection import train_test_split as sk_train_test_split
from dask_ml.model_selection import train_test_split as dd_train_test_split

df = pd.DataFrame(np.random.rand(10, 3), columns=["y", "x1", "x2"])
ddf = dd.from_pandas(df, npartitions=5)  # split the 10 rows into 5 partitions

With sklearn.model_selection, the original order is maintained (i.e. no shuffling):

y = df["y"]
X = df[["x1", "x2"]]

X_train, X_test, y_train, y_test = sk_train_test_split(X, y, test_size=0.5, shuffle=False)
y_train, y_test
Output:
(0    0.166713
 1    0.961016
 2    0.483907
 3    0.979503
 4    0.553724
 Name: y, dtype: float64,
 5    0.158432
 6    0.078795
 7    0.440427
 8    0.673160
 9    0.657797
 Name: y, dtype: float64)

With dask_ml.model_selection and a pandas DataFrame, the order is also maintained (i.e. no shuffling):

y = df["y"]
X = df[["x1", "x2"]]

X_train, X_test, y_train, y_test = dd_train_test_split(X, y, test_size=0.5, shuffle=False)
y_train, y_test
Output:
(0    0.166713
 1    0.961016
 2    0.483907
 3    0.979503
 4    0.553724
 Name: y, dtype: float64,
 5    0.158432
 6    0.078795
 7    0.440427
 8    0.673160
 9    0.657797
 Name: y, dtype: float64)
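Instead of comparing indices by eye, one way to check whether a split preserved order is to verify that the training index is exactly the head of the original index and the test index is the matching tail. This is a minimal sketch assuming a pandas-style index; is_ordered_split is a hypothetical helper, and the index values passed to it are copied from the outputs in this issue.

```python
import pandas as pd


def is_ordered_split(full_index, train_index, test_index):
    """Hypothetical helper: True if train is a prefix of full_index and
    test is exactly the remaining suffix, both in the original order."""
    full = list(full_index)
    n_train = len(train_index)
    return list(train_index) == full[:n_train] and list(test_index) == full[n_train:]


full = pd.RangeIndex(10)

# Index values from the sklearn and dask-ml + pandas outputs
print(is_ordered_split(full, [0, 1, 2, 3, 4], [5, 6, 7, 8, 9]))  # True

# Index values from the dask-ml + Dask DataFrame output in this issue
print(is_ordered_split(full, [0, 1, 2, 3, 8, 9], [4, 5, 6, 7]))  # False
```

With real objects, the inputs would be y.index, y_train.index and y_test.index (calling .compute() first on the Dask results).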

With dask_ml.model_selection and a Dask DataFrame, the order is NOT maintained and the train/test sizes are incorrect (6 and 4 rows instead of 5 and 5):

y = ddf["y"]
X = ddf[["x1", "x2"]]

X_train, X_test, y_train, y_test = dd_train_test_split(X, y, test_size=0.5, shuffle=False)
y_train.compute(), y_test.compute()
Output:
(0    0.166713
 1    0.961016
 2    0.483907
 3    0.979503
 8    0.673160
 9    0.657797
 Name: y, dtype: float64,
 4    0.553724
 5    0.158432
 6    0.078795
 7    0.440427
 Name: y, dtype: float64)
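One hypothesis consistent with this output, not verified here against dask-ml's source, is that for Dask DataFrames each row is assigned to train or test at random within its partition (the way dask's DataFrame.random_split works), so shuffle=False only keeps the order inside each split rather than producing a head/tail split. Note that the second positional argument of dd.from_pandas is npartitions, so ddf in this example has 5 partitions. The numpy sketch below models that hypothesis; partitionwise_random_split is a hypothetical helper, and using np.array_split to form partitions is a simplification of how dask chooses partition boundaries.

```python
import numpy as np


def partitionwise_random_split(n_rows, npartitions, frac_train, seed=None):
    """Hypothetical model (not dask-ml's code): assign each row to train or
    test at random, independently, within contiguous partitions."""
    rng = np.random.default_rng(seed)
    train, test = [], []
    # np.array_split cuts the row positions into contiguous partitions
    for part in np.array_split(np.arange(n_rows), npartitions):
        # each row goes to train with probability frac_train
        to_train = rng.random(len(part)) < frac_train
        train.extend(part[to_train].tolist())
        test.extend(part[~to_train].tolist())
    return train, test


# 10 rows in 5 partitions, 50/50 target as in the issue
train, test = partitionwise_random_split(10, npartitions=5, frac_train=0.5, seed=0)
print(len(train), len(test), train, test)
```

Under this model the split is only 50/50 on average, which would explain the 6/4 sizes reported above, while rows inside each split stay in ascending order, as they do in the Dask output.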

Environment: