koaning / scikit-lego

Extra blocks for scikit-learn pipelines.
https://koaning.github.io/scikit-lego/
MIT License
1.28k stars 117 forks

[BUG] GroupedPredictor inconsistency for predict_proba having different classes per group #579

Closed fabioscantamburlo closed 1 year ago

fabioscantamburlo commented 1 year ago

Hello scikit-lego users. When using predict_proba paired with GroupedPredictor and a classifier on a DataFrame with different labels per group, the final probability matrix is collapsed to the left, ignoring the label order. This yields inconsistencies in the final output, especially when using a high number of labels. To get a sound result, every label would have to appear at least once in every group, which is somewhat unrealistic.

Here is a snippet of code:

import pandas as pd
import numpy as np

from sklego.meta import GroupedPredictor
from sklearn.linear_model import LogisticRegression

np.random.seed(43)

group_size = 5
n_groups = 2
df = pd.DataFrame({
    "group": ["A"] * group_size + ["B"] * group_size,
    "x": np.random.normal(size=group_size * n_groups),
    "y": np.hstack([
        np.random.choice([0, 1, 2], size=group_size),
        np.random.choice([0, 2], size=group_size),
        ])
})

print(df.groupby('group').agg({'y': set}))

X, y = df[["x", "group"]], df["y"]
model = GroupedPredictor(LogisticRegression(), groups=["group"])
_ = model.fit(X, y)
y_prob = model.predict_proba(X)

print(y_prob.round(2))

Outputs:


>>>                y
>>> group           
>>> A      {0, 1, 2}
>>> B         {0, 2}

>>> [[0.45 0.19 0.36]#grp A
>>> [0.3  0.23 0.47]
>>> [0.37 0.21 0.42]
>>> [0.35 0.22 0.44]
>>> [0.53 0.16 0.31]
>>> [0.79 0.21  nan]#grp B
>>> [0.8  0.2   nan]
>>> [0.81 0.19  nan]
>>> [0.81 0.19  nan]
>>> [0.79 0.21  nan]]

# Expected:
>>> [[0.45 0.19 0.36]#grp A
>>> [0.3  0.23 0.47]
>>> [0.37 0.21 0.42]
>>> [0.35 0.22 0.44]
>>> [0.53 0.16 0.31]
>>> [0.79 nan  0.21]#grp B
>>> [0.8  nan   0.2]
>>> [0.81 nan  0.19]
>>> [0.81 nan  0.19]
>>> [0.79 nan  0.21]]
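As a user-side workaround (this is a hypothetical sketch, not part of the scikit-lego API), one can fit one estimator per group and align each group's predict_proba columns to the union of all classes via the estimator's `classes_` attribute and `DataFrame.reindex`, so that missing classes show up as NaN columns in the right position:

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(43)

# Same setup as the snippet above, but with deterministic labels so that
# group "B" is guaranteed to miss class 1 while group "A" has all classes.
df = pd.DataFrame({
    "group": ["A"] * 5 + ["B"] * 5,
    "x": rng.normal(size=10),
    "y": [0, 1, 2, 0, 1] + [0, 2, 0, 2, 0],
})

# Fit one estimator per group (mimicking what GroupedPredictor does).
estimators = {
    g: LogisticRegression().fit(part[["x"]], part["y"])
    for g, part in df.groupby("group")
}

all_classes = np.unique(df["y"])  # global label set: [0, 1, 2]

aligned = []
for g, part in df.groupby("group"):
    est = estimators[g]
    # est.classes_ tells us which column corresponds to which label.
    proba = pd.DataFrame(
        est.predict_proba(part[["x"]]),
        columns=est.classes_,
        index=part.index,
    )
    # Reindexing against the global class set inserts NaN columns for the
    # classes this group never saw, in the correct position.
    aligned.append(proba.reindex(columns=all_classes))

result = pd.concat(aligned).sort_index()
print(result.round(2))
```

With this alignment, group "B" rows get their NaN under class 1 rather than class 2, matching the expected output above.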
FBruzzesi commented 1 year ago

Hey @fabioscantamburlo, thanks for reporting the bug. At the moment there is no internal check for these edge cases, but it may be worth looking into it and adding such a mechanism.
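For illustration, such a check could look roughly like the sketch below (a hypothetical helper, not existing scikit-lego code): it compares each group's observed labels against the global label set and warns when a group is missing some.

```python
import warnings
import numpy as np

def check_group_classes(y, groups):
    """Hypothetical validation sketch: warn when a group is missing
    labels that appear elsewhere in the data, which would make naive
    column-stacking of per-group predict_proba outputs inconsistent."""
    y = np.asarray(y)
    groups = np.asarray(groups)
    global_classes = set(np.unique(y))
    for g in np.unique(groups):
        group_classes = set(np.unique(y[groups == g]))
        missing = global_classes - group_classes
        if missing:
            warnings.warn(
                f"Group {g!r} is missing classes {sorted(missing)}; "
                "predict_proba columns must be aligned to the global class set."
            )

# Example with the data from the snippet above: group "B" never sees class 1.
check_group_classes(
    y=[0, 1, 2, 0, 1] + [0, 2, 0, 2, 0],
    groups=["A"] * 5 + ["B"] * 5,
)
```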

fabioscantamburlo commented 1 year ago

I would like to work on this if possible.

FBruzzesi commented 1 year ago

Glad to hear that! Looking forward to a PR to address this issue😊