scikit-learn-contrib / imbalanced-learn

A Python Package to Tackle the Curse of Imbalanced Datasets in Machine Learning
https://imbalanced-learn.org
MIT License
6.85k stars 1.29k forks

ValueError: Found array with dim 4. RandomOverSampler expected <= 2 #966

Closed: LHXqwq closed this issue 1 year ago

LHXqwq commented 1 year ago

I want to apply oversampling to an image classification task, but I get "ValueError: Found array with dim 4. RandomOverSampler expected <= 2." How can I use imbalanced-learn with image data?

vitaliset commented 1 year ago

Most scikit-learn-compatible tools only accept a 2D array for X, with shape (n_samples, n_features). You can work around this with NumPy's reshape.

Note that reshape lets us go back and forth between the two shapes of an array:

import numpy as np

X = np.zeros((10, 20, 30, 40))

original_shape = X.shape
print(original_shape)
# (10, 20, 30, 40)

# Keep the sample axis and flatten the rest: 20 * 30 * 40 = 24000 features
X_reshaped = X.reshape(original_shape[0], -1)
print(X_reshaped.shape)
# (10, 24000)

X_undo_reshape = X_reshaped.reshape(original_shape)
print(X_undo_reshape.shape)
# (10, 20, 30, 40)

print((X_undo_reshape == X).all())
# True

So we can use the same trick before giving X to the RandomOverSampler:

import numpy as np
from imblearn.over_sampling import RandomOverSampler

X = np.array([[[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]],
             [[[13, 14], [15, 16], [17, 18]], [[19, 20], [21, 22], [23, 24]]],
             [[[25, 26], [27, 28], [29, 30]], [[31, 32], [33, 34], [35, 36]]]])
y = np.array([0, 0, 1])
print(X.shape)
# (3, 2, 3, 2)

# Passing the 4D array directly fails:
# X_over, y_over = RandomOverSampler().fit_resample(X, y)
# ValueError: Found array with dim 4. RandomOverSampler expected <= 2.

# Flatten each sample to a 1D feature vector before resampling
X_reshaped = X.reshape(X.shape[0], -1)
X_over, y_over = RandomOverSampler().fit_resample(X_reshaped, y)

# The sampler drew one extra sample, so the first dimension of the
# original shape has to be updated before reshaping back.
return_shape = list(X.shape)
return_shape[0] = X_over.shape[0]
X_over = X_over.reshape(return_shape)

print(X_over.shape, y_over.shape)
# (4, 2, 3, 2) (4,)
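
If you need this in several places, the flatten, resample, and reshape-back steps can be wrapped in a small helper. This is only a sketch: fit_resample_nd is a hypothetical name, not part of imbalanced-learn, and it assumes the sampler exposes fit_resample(X, y) and returns a 2D array with the same number of features it was given.

```python
import numpy as np


def fit_resample_nd(sampler, X, y):
    """Resample an N-dimensional X with a sampler that only accepts 2D input.

    Hypothetical helper (not part of imbalanced-learn). `sampler` is assumed
    to be any object with a fit_resample(X, y) method, such as
    RandomOverSampler.
    """
    X = np.asarray(X)
    # Keep the sample axis and flatten every other axis into features.
    X_flat = X.reshape(X.shape[0], -1)
    X_res, y_res = sampler.fit_resample(X_flat, y)
    # The number of samples may have changed, so let NumPy infer it (-1)
    # and restore the original trailing dimensions.
    return np.asarray(X_res).reshape(-1, *X.shape[1:]), y_res
```

With the X and y from the example above, `X_over, y_over = fit_resample_nd(RandomOverSampler(), X, y)` should give the same shapes as the manual version.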

glemaitre commented 1 year ago

imbalanced-learn only supports tabular data. We will not treat this as a bug.