iancovert / sage

For calculating global feature importance using Shapley values.
MIT License

Shape mismatch on XGBClassifier #7

Closed · strumke closed this issue 3 years ago

strumke commented 3 years ago

Hi. I hope and suspect this is just a misunderstanding on my part, but I cannot make sage work with the XGBoost classifier. I have included a minimal example below, using the Boston Housing dataset. Here, XGBRegressor works fine but XGBClassifier produces an error, although y_pred has the same shape in both cases. The error* arises in utils.py in __call__(self, pred, target), line 159, probably because the if clause on line 156 isn't triggered, so pred doesn't get reshaped (?). I'm not sure, but would appreciate your input. And thanks for a great package :-)

*ValueError: operands could not be broadcast together with shapes (512,2) (512,)

from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import xgboost as xgb
import sage
boston = load_boston()
boston_dataset = pd.DataFrame(boston.data, columns=boston.feature_names)
boston_dataset['MEDV'] = boston.target
features = ["RM", "AGE", "TAX", "CRIM", "PTRATIO"]
x_data = np.array(boston_dataset[features])
medv = np.array(boston_dataset.MEDV)
mean = np.mean(medv)
y_data = np.array([1 if _m < mean else 0 for _m in medv]) # make targets binary
x_train, x_test, y_train, y_test = train_test_split(x_data, y_data, test_size=0.33, random_state=42)
#model = xgb.XGBClassifier(use_label_encoder=False).fit(x_train, y_train) # Doesn't work
model = xgb.XGBRegressor().fit(x_train, y_train) # Works
y_pred = model.predict(x_test)
print(y_pred.shape)
imputer = sage.MarginalImputer(model, x_train[:512])
estimator = sage.PermutationEstimator(imputer, 'mse')
sage_values = estimator(x_test, y_test)
sage_values.plot(features)
iancovert commented 3 years ago

Hi Inga, thanks for posting. I've checked out your code and there's just one small problem: SAGE is told to use 'mse' as a loss function, but for classification models it's better to use 'cross entropy' (in the third to last line). With this change, the code snippet runs properly!
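Concretely, it's a one-word change in the estimator setup. With the same variables as in the snippet above, the classifier version becomes:

model = xgb.XGBClassifier(use_label_encoder=False).fit(x_train, y_train)
imputer = sage.MarginalImputer(model, x_train[:512])
estimator = sage.PermutationEstimator(imputer, 'cross entropy')  # 'mse' -> 'cross entropy'
sage_values = estimator(x_test, y_test)
sage_values.plot(features)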

It's a little concerning, though, that it fails with MSE, because it should still be able to run... The issue is that, under the hood, SAGE detects that we're using a classification model (at this line), and predictions are then made using the predict_proba function. This returns an array of size (batch, num_classes), or (batch, 2) in this case, while the targets still have shape (batch,), so the elementwise operations in the MSE loss fail to broadcast.
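A minimal numpy sketch of the mismatch (illustrative only, not the sage internals):

import numpy as np

pred = np.zeros((512, 2))  # predict_proba output: one column per class
target = np.zeros(512)     # binary labels: shape (512,)
diff = pred - target       # ValueError: operands could not be broadcast together with shapes (512,2) (512,)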

What do you think, are you okay with the solution of having to use cross entropy loss for classification models? I would make the choice of loss function automatic to avoid potential confusion, but I could imagine someone wanting to use 0-1 accuracy rather than cross entropy. Or perhaps there could just be a more informative error message. Let me know what you think.

strumke commented 3 years ago

Hi Ian, thanks for the response! The workaround works for me, but in case users have strong preferences regarding the loss function - or want to implement their own, if you'd like to provide that functionality - perhaps a permanent solution in the form of a check would be possible (?). I'm not sure, but doing predict_proba()[:, 0] would at least yield the right shape. (I don't know what interpretation this would correspond to for SAGE, though; I'm not sufficiently familiar with it yet :) ) Regarding a more informative error message: it would have been useful to know that this is a predict/predict_proba problem. Thanks a lot again!
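For reference, the slicing suggested above would look like this (a sketch reusing model and x_test from the snippet above, with model an XGBClassifier; as noted, the first column is the probability of class 0, so its interpretation differs from a regression prediction):

proba = model.predict_proba(x_test)  # shape (batch, 2): columns are P(class 0), P(class 1)
pred = proba[:, 0]                   # shape (batch,), matches y_test, but is a probability, not a label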

iancovert commented 3 years ago

Okay, so my solution for now is to add a more specific error message that clarifies the dimensions of the pred and target variables. With this, hopefully people will have enough information to realize that the model is making probabilistic predictions.
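A hypothetical sketch of the kind of check described (illustrative only, not the actual sage source; the names mirror the pred and target variables mentioned above):

def check_shapes(pred, target):
    # Hypothetical helper: raise an informative error when predictions look
    # like class probabilities but targets are plain labels.
    if pred.ndim == 2 and target.ndim == 1 and pred.shape[0] == target.shape[0]:
        raise ValueError(
            f'pred has shape {pred.shape} but target has shape {target.shape}; '
            'the model appears to return class probabilities (e.g. predict_proba), '
            "so consider using 'cross entropy' rather than 'mse' as the loss")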

Two ideas that I decided against:

  1. Constraining the choice of loss function seemed too restrictive, because people may actually want to use MSE for classification models in some situations.
  2. Adding a print function to say which prediction function is being used (e.g., predict vs. predict_proba for sklearn classifiers) is a bit clunky, and I thought it would be ugly to do that whenever someone is setting up an estimator.

I'm going to close the issue for now. Hopefully this works, but we'll see if other people mention this problem!