scikit-learn-contrib / imbalanced-learn

A Python Package to Tackle the Curse of Imbalanced Datasets in Machine Learning
https://imbalanced-learn.org
MIT License

Unable to achieve balanced samples for each estimator (Balanced Bagging Classifier) #1086

Closed chungkae closed 4 months ago

chungkae commented 4 months ago

When I run the official BalancedBaggingClassifier example and then check the samples of each estimator through estimators_samples_, the samples are not balanced. Is this expected, or is there something I need to adjust? Thanks!

Here is my code:

from collections import Counter
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from imblearn.ensemble import BalancedBaggingClassifier

X, y = make_classification(
    n_classes=2,
    class_sep=2,
    weights=[0.1, 0.9],
    n_informative=3,
    n_redundant=1,
    flip_y=0,
    n_features=20,
    n_clusters_per_class=1,
    n_samples=1000,
    random_state=10,
)
print('Original dataset shape %s' % Counter(y))
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
bbc = BalancedBaggingClassifier(random_state=42)
bbc.fit(X_train, y_train)
y_pred = bbc.predict(X_test)

for i, samples in enumerate(bbc.estimators_samples_):
    y_sampled = y_train[samples]
    counter = Counter(y_sampled)
    print(f"Estimator {i}: {counter}")

Result:

Estimator 0: Counter({1: 661, 0: 89})
Estimator 1: Counter({1: 674, 0: 76})
Estimator 2: Counter({1: 662, 0: 88})
Estimator 3: Counter({1: 687, 0: 63})
Estimator 4: Counter({1: 670, 0: 80})
Estimator 5: Counter({1: 676, 0: 74})
Estimator 6: Counter({1: 671, 0: 79})
Estimator 7: Counter({1: 665, 0: 85})
Estimator 8: Counter({1: 671, 0: 79})
Estimator 9: Counter({1: 662, 0: 88})

chkoar commented 4 months ago

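Yes, your code and results show the intended behavior. estimators_samples_ reports the bootstrap sample drawn for each estimator; the class distribution is modified afterwards, at the classifier level, before each classifier is fitted. I changed your code to make this visible: a DecisionTreeClassifier subclass stores the y it receives in fit, so the distribution actually delivered to each classifier can be printed.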

from collections import Counter
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from imblearn.ensemble import BalancedBaggingClassifier
from sklearn import tree

class DecisionTreeClassifier(tree.DecisionTreeClassifier):
    # Record the labels this tree is actually fitted on, for later inspection.
    def fit(self, X, y):
        self.y_ = y
        return super().fit(X, y)

X, y = make_classification(
    n_classes=2,
    class_sep=2,
    weights=[0.1, 0.9],
    n_informative=3,
    n_redundant=1,
    flip_y=0,
    n_features=20,
    n_clusters_per_class=1,
    n_samples=1000,
    random_state=10,
)
print("Original dataset shape %s" % Counter(y))
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
bbc = BalancedBaggingClassifier(DecisionTreeClassifier(), random_state=42)
bbc.fit(X_train, y_train)
y_pred = bbc.predict(X_test)

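for i, samples in enumerate(bbc.estimators_samples_):
    # Labels of the bootstrap sample drawn for estimator i (before balancing).
    y_sampled = y_train[samples]
    counter = Counter(y_sampled)
    print(f"Estimator {i}: {counter}")
    # Labels the wrapped classifier (second step of estimator i) was fitted on.
    print(f"Delivered distribution {i}: {Counter(bbc.estimators_[i][1].y_)}")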

Original dataset shape Counter({1: 900, 0: 100})
Estimator 0: Counter({1: 661, 0: 89})
Delivered distribution 0: Counter({0: 89, 1: 89})
Estimator 1: Counter({1: 674, 0: 76})
Delivered distribution 1: Counter({0: 76, 1: 76})
Estimator 2: Counter({1: 662, 0: 88})
Delivered distribution 2: Counter({0: 88, 1: 88})
Estimator 3: Counter({1: 687, 0: 63})
Delivered distribution 3: Counter({0: 63, 1: 63})
Estimator 4: Counter({1: 670, 0: 80})
Delivered distribution 4: Counter({0: 80, 1: 80})
Estimator 5: Counter({1: 676, 0: 74})
Delivered distribution 5: Counter({0: 74, 1: 74})
Estimator 6: Counter({1: 671, 0: 79})
Delivered distribution 6: Counter({0: 79, 1: 79})
Estimator 7: Counter({1: 665, 0: 85})
Delivered distribution 7: Counter({0: 85, 1: 85})
Estimator 8: Counter({1: 671, 0: 79})
Delivered distribution 8: Counter({0: 79, 1: 79})
Estimator 9: Counter({1: 662, 0: 88})
Delivered distribution 9: Counter({0: 88, 1: 88})
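
To illustrate why the two counts differ, here is a rough standard-library sketch of the two stages as they appear in the output above: a bootstrap draw, followed by random under-sampling down to the minority count. This is a simplified stand-in and not the library's implementation. The helper name bootstrap_then_undersample, the toy label list, and the seeds are all assumptions made for illustration.

```python
import random
from collections import Counter


def bootstrap_then_undersample(y, seed):
    """Hypothetical helper: mimic bootstrap + balancing for one estimator."""
    rng = random.Random(seed)
    n = len(y)
    # Stage 1: bootstrap indices drawn with replacement from the training set.
    # These play the role of what estimators_samples_ reports.
    bootstrap_idx = [rng.randrange(n) for _ in range(n)]
    # Group the bootstrap indices by class label.
    by_class = {}
    for idx in bootstrap_idx:
        by_class.setdefault(y[idx], []).append(idx)
    # Stage 2: under-sample every class to the size of the smallest class.
    n_min = min(len(indices) for indices in by_class.values())
    balanced_idx = []
    for label in sorted(by_class):
        balanced_idx.extend(rng.sample(by_class[label], n_min))
    return bootstrap_idx, balanced_idx


# Toy labels with the same 1:9 class ratio as the dataset in this thread
# (750 samples; the exact per-class split is an assumption).
y_train = [0] * 75 + [1] * 675

for i in range(3):
    boot, balanced = bootstrap_then_undersample(y_train, seed=i)
    print(f"Estimator {i}: {Counter(y_train[j] for j in boot)}")
    print(f"Delivered distribution {i}: {Counter(y_train[j] for j in balanced)}")
```

The first line printed per estimator stays imbalanced, like the estimators_samples_ counts, while the second has equal counts for both classes, matching the pattern of the "Delivered distribution" lines.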

chungkae commented 4 months ago

@chkoar Thanks for your detailed and informative answer.