scikit-learn-contrib / imbalanced-learn

A Python Package to Tackle the Curse of Imbalanced Datasets in Machine Learning
https://imbalanced-learn.org
MIT License

[BUG] `_transform_one` fails on sparse DataFrame #1053

Closed: ts2095 closed this issue 10 months ago

ts2095 commented 1 year ago

Describe the bug

Rebuilding the DataFrame after generating additional samples fails when the original DataFrame was sparse: `pd.DataFrame.sparse.from_spmatrix` is required instead of `pd.DataFrame`.
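To illustrate the point, here is a minimal sketch of how the reconstruction step could choose the constructor based on the input type. The helper name `rebuild_frame` and the stand-in matrix `X_res` are hypothetical and are not taken from imbalanced-learn; only `scipy.sparse.issparse` and `pd.DataFrame.sparse.from_spmatrix` are existing library calls.

```python
import numpy as np
import pandas as pd
from scipy import sparse


def rebuild_frame(array, columns):
    """Hypothetical helper: rebuild a DataFrame from resampled data."""
    if sparse.issparse(array):
        # SciPy sparse matrices need the dedicated sparse constructor.
        return pd.DataFrame.sparse.from_spmatrix(array, columns=columns)
    # Dense arrays can go through the regular constructor.
    return pd.DataFrame(array, columns=columns)


# Stand-in for resampled data: a 26 x 2 SciPy sparse matrix (assumed shape).
X_res = sparse.csr_matrix(np.tile([[0.0, 0.0], [1.0, 1.0]], (13, 1)))

frame = rebuild_frame(X_res, columns=["a", "b"])
```

With this dispatch, `frame` keeps both columns and each column has a sparse dtype, while dense input still goes through `pd.DataFrame` unchanged.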

Steps/Code to Reproduce

from imblearn.over_sampling import RandomOverSampler
import pandas as pd

df = pd.DataFrame({"a": [0, 1] * 10, "b": [0, 1] * 10}, dtype=pd.SparseDtype(float, 0))
y = pd.Series([0] * 18 + [1] * 2)

ros = RandomOverSampler(sampling_strategy=1, random_state=42, shrinkage=1)
ros.fit_resample(df, y)

Expected Results

`fit_resample` should complete without raising an error.

Actual Results

[...]
  File "[...]/site-packages/imblearn/utils/_validation.py", line 39, in transform
    X = self._transfrom_one(X, self.x_props)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "[...]/site-packages/imblearn/utils/_validation.py", line 64, in _transfrom_one
    ret = pd.DataFrame(array, columns=props["columns"])
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "[...]/site-packages/pandas/core/frame.py", line 798, in __init__
    mgr = ndarray_to_mgr(
          ^^^^^^^^^^^^^^^
  File "[...]/site-packages/pandas/core/internals/construction.py", line 337, in ndarray_to_mgr
    _check_values_indices_shape_match(values, index, columns)
  File "[...]/site-packages/pandas/core/internals/construction.py", line 408, in _check_values_indices_shape_match
    raise ValueError(f"Shape of passed values is {passed}, indices imply {implied}")
ValueError: Shape of passed values is (26, 1), indices imply (26, 2)
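A possible workaround, sketched here as an assumption rather than a confirmed fix, is to densify the DataFrame before resampling and convert the result back to a sparse dtype afterwards. The resampling call itself is left as a comment so the snippet only depends on pandas; whether this sidesteps the error in a given imbalanced-learn version has not been verified here, and densifying may be costly for large data.

```python
import pandas as pd

sparse_dtype = pd.SparseDtype(float, 0)
df = pd.DataFrame({"a": [0, 1] * 10, "b": [0, 1] * 10}, dtype=sparse_dtype)

# Convert the sparse columns to ordinary dense float columns.
dense = df.sparse.to_dense()

# Assumed step: resample the dense frame instead of the sparse one, e.g.
# dense, y_res = ros.fit_resample(dense, y)

# Convert the (resampled) dense frame back to sparse columns.
restored = dense.astype(sparse_dtype)
```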

Versions

          pip: 23.3
   setuptools: 68.2.2
        numpy: 1.26.2
        scipy: 1.11.3
       Cython: None
       pandas: 2.0.3
   matplotlib: 3.8.1
       joblib: 1.3.2
threadpoolctl: 3.2.0
Built with OpenMP: True
threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 12
         prefix: libopenblas
       filepath: [...]/site-packages/numpy/.dylibs/libopenblas64_.0.dylib
        version: 0.3.23.dev
threading_layer: pthreads
   architecture: armv8
       user_api: openmp
   internal_api: openmp
    num_threads: 12
         prefix: libomp
       filepath: [...]/site-packages/sklearn/.dylibs/libomp.dylib
        version: None
       user_api: blas
   internal_api: openblas
    num_threads: 12
         prefix: libopenblas
       filepath: [...]/site-packages/scipy/.dylibs/libopenblas.0.dylib
        version: 0.3.21.dev
threading_layer: pthreads
   architecture: armv8
macOS-14.1-arm64-arm-64bit
Python 3.11.5 (main, Sep 11 2023, 08:31:25) [Clang 14.0.6 ]
NumPy 1.26.2
SciPy 1.11.3
Scikit-Learn 1.3.0
Imbalanced-Learn 0.11.0