rapidsai / cuml

cuML - RAPIDS Machine Learning Library
https://docs.rapids.ai/api/cuml/stable/
Apache License 2.0

Fix train_test_split for string columns #6088

Closed: dantegd closed this 1 month ago

dantegd commented 1 month ago

Closes #5834

Before the fix, the following script failed:

```python
import cudf
from cuml.model_selection import train_test_split

SEED = 1
df_a = cudf.DataFrame({'a': [0, 1, 2, 3, 4],
                       'b': [5, 6, 7, 8, 9],
                       'c': ['High', 'Low', 'High', 'High', 'Low']
                       })
target = cudf.Series([1, 1, 1, 0, 0])

# Column 'c' holds strings, so not every column is numeric (prints False)
all_numeric = all(cudf.api.types.is_numeric_dtype(df_a[col]) for col in df_a.columns)
print(all_numeric)

tr, te, ytr, yte = train_test_split(X=df_a, y=target, test_size=0.3,
                                    random_state=SEED, stratify=target)

print(tr)
```

Running it failed with multiple errors of this form (end of the traceback shown):

```python
  File "/home/coder/.conda/envs/rapids/lib/python3.12/site-packages/cudf/utils/performance_tracking.py", line 51, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/coder/.conda/envs/rapids/lib/python3.12/site-packages/cudf/core/frame.py", line 358, in _get_columns_by_label
    return self._from_data_like_self(self._data.select_by_label(labels))
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/coder/.conda/envs/rapids/lib/python3.12/site-packages/cudf/core/column_accessor.py", line 401, in select_by_label
    return self._select_by_label_grouped(key)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/coder/.conda/envs/rapids/lib/python3.12/site-packages/cudf/core/column_accessor.py", line 563, in _select_by_label_grouped
    result = self._grouped_data[key]
             ~~~~~~~~~~~~~~~~~~^^^^^
KeyError: '__cuda_array_interface__'
```

After the fix, `train_test_split` works for cuDF DataFrames with string columns:

```
(rapids) coder ➜ ~ $ python cudfstr.py
   a  b     c
3  3  8  High
4  4  9   Low
2  2  7  High
1  1  6   Low
```
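A stratified split never needs to convert the feature columns into a device array; it only needs row positions chosen from the labels, after which rows of any dtype can be gathered by position. The sketch below illustrates that idea on CPU with the standard library only. `stratified_split_indices` is a hypothetical helper, not cuML's API or the code in this PR, and the per-class rounding rule is an assumption that may not match how cuML sizes the test set.

```python
import random
from collections import defaultdict


def stratified_split_indices(labels, test_size, seed):
    """Hypothetical helper: return (train_idx, test_idx) row positions.

    Only the labels are inspected, never the feature columns, so the
    feature frame may contain strings or any other dtype.
    """
    if not 0 < test_size < 1:
        raise ValueError("test_size must be in (0, 1)")
    rng = random.Random(seed)  # seeded for reproducible splits

    # Group row positions by class label
    by_class = defaultdict(list)
    for pos, label in enumerate(labels):
        by_class[label].append(pos)

    train_idx, test_idx = [], []
    for label in sorted(by_class):
        positions = by_class[label]
        rng.shuffle(positions)
        # Assumed rule: test rows per class = class size * test_size, rounded half up
        n_test = int(len(positions) * test_size + 0.5)
        test_idx.extend(positions[:n_test])
        train_idx.extend(positions[n_test:])
    return train_idx, test_idx


labels = [1, 1, 1, 0, 0]
train_idx, test_idx = stratified_split_indices(labels, test_size=0.3, seed=1)
print(train_idx, test_idx)
```

With cuDF, the positions could then be applied to the original frame, for example `df_a.iloc[train_idx]` and `df_a.iloc[test_idx]`, with the labels taken from `target.to_pandas().tolist()`; because positional gathering works column by column, the string column `c` comes along unchanged.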

Still to do: add a test and probably make a small fix for cudf.pandas. There is some redundancy in the code, which can be cleaned up in a follow-up for a later release so that this fix gets into 24.10.

dantegd commented 1 month ago

@bdice that code wasn't mine, and I hadn't touched it in this PR except to re-indent it, but I was super happy to fix and clean it up!