Closed: strumke closed this issue 3 years ago
Hi Inga, thanks for posting. I've checked out your code and there's just one small problem: SAGE is told to use 'mse' as the loss function, but for classification models it should be 'cross entropy' (in the third-to-last line). With this change, the code snippet runs properly!
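For concreteness, a minimal self-contained sketch of the change, assuming the `MarginalImputer` / `PermutationEstimator` interface from the SAGE README (the classifier and data here are placeholders, not the original snippet, and the exact call signatures may differ by version):

```python
import sage
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Any binary classifier with predict_proba works for illustration.
X, Y = make_classification(n_samples=1000, n_features=8, random_state=0)
model = LogisticRegression(max_iter=1000).fit(X, Y)

imputer = sage.MarginalImputer(model, X[:128])
# 'cross entropy' instead of 'mse' is the one-line change for classifiers.
estimator = sage.PermutationEstimator(imputer, 'cross entropy')
sage_values = estimator(X, Y)
```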
It's a little concerning that it fails with MSE, though, because it should still be able to run... The issue is that, under the hood, SAGE detects that we're using a classification model (at this line), and predictions are then made using the `predict_proba` function. This returns an array of size `(batch, num_classes)`, or `(batch, 2)` in this case.
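Roughly what goes wrong inside an MSE-style loss in that situation (an illustrative sketch with stand-in arrays, using the shapes from the traceback below):

```python
import numpy as np

pred = np.random.rand(512, 2)   # predict_proba output: (batch, num_classes)
target = np.random.rand(512)    # regression-style targets: (batch,)

# The elementwise subtraction cannot broadcast (512, 2) against (512,),
# which raises:
# ValueError: operands could not be broadcast together with shapes (512,2) (512,)
loss = np.mean((pred - target) ** 2)
```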
What do you think, are you okay with the solution of having to use cross entropy loss for classification models? I would make the choice of loss function automatic to avoid potential confusion, but I could imagine someone wanting to use 0-1 accuracy rather than cross entropy. Or perhaps there could just be a more informative error message. Let me know what you think.
Hi Ian, thanks for the response! The workaround works for me, but in case users have strong preferences regarding the loss function - or want to implement their own, though I don't know if you'd like to provide that functionality - perhaps it would be possible to come up with a permanent solution in the form of a check? I'm not sure, but doing `predict_proba()[:, 0]` would at least yield the right shape. (I don't know what interpretation this would correspond to for SAGE, though; I'm not sufficiently familiar with it yet :) ) Regarding a more informative error message: it would have been useful to know that this is a `predict`/`predict_proba` problem. Thanks a lot again!
Okay, so my solution for now is to add a more specific error message that clarifies the dimensions of the `pred` and `target` variables. With this, hopefully people will have enough information to realize that the model is making probabilistic predictions.
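Purely as an illustration of the kind of check, something along these lines (hypothetical names, not the exact code added to utils.py):

```python
# Hypothetical sketch of a shape check before the loss is computed.
def check_loss_shapes(pred, target, loss_name):
    if loss_name == 'mse' and pred.ndim == 2 and target.ndim == 1:
        raise ValueError(
            f'pred has shape {pred.shape} but target has shape {target.shape}; '
            'the model appears to return class probabilities (e.g. via '
            'predict_proba), so consider the cross entropy loss instead of mse.'
        )
```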
Two ideas that I decided against:

1. Making the choice of loss function automatic for classification models, since some users may prefer 0-1 accuracy over cross entropy.
2. Checking how a model makes predictions (`predict` vs. `predict_proba` for sklearn classifiers) when the estimator is set up; that check is a bit clunky, and I thought it would be ugly to do it whenever someone is setting up an estimator (a rough sketch of that kind of detection is below).

I'm going to close the issue for now. Hopefully this works, but we'll see if other people mention this problem!
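For illustration, the setup-time detection described in the second idea might look roughly like this (a hypothetical sketch, not code from the package):

```python
# Hypothetical sketch of silently picking a default loss when the estimator
# is set up, based on how the model makes predictions.
def infer_default_loss(model):
    return 'cross entropy' if hasattr(model, 'predict_proba') else 'mse'
```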
Hi. I hope and suspect this is just a misunderstanding on my part, but I cannot make sage work using the XGBoost classifier. I have included a minimal example below, using the Boston Housing dataset. Here, the `XGBRegressor` works fine but the `XGBClassifier` produces an error, although `y_pred` has the same shape in both cases. The error* arises in utils.py in `__call__(self, pred, target)`, line 159, probably because the if clause on line 156 isn't triggered, so `pred` doesn't get reshaped (?). I'm not sure but would appreciate your input. And thanks for a great package :-)
*ValueError: operands could not be broadcast together with shapes (512,2) (512,)
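A rough reconstruction of the kind of setup described above, not the original snippet: the Boston Housing loader has since been removed from scikit-learn, so California housing stands in, the binary target is an assumed median split, and the classifier already uses the 'cross entropy' fix discussed above (switching it back to 'mse' reproduces the ValueError above):

```python
import numpy as np
import sage
from sklearn.datasets import fetch_california_housing
from xgboost import XGBClassifier, XGBRegressor

data = fetch_california_housing()
X, y = data.data[:2048], data.target[:2048]   # subsample to keep the run short
y_cls = (y > np.median(y)).astype(int)        # assumed binarization for the classifier

# Regression: 'mse' works as expected.
reg = XGBRegressor().fit(X, y)
reg_imputer = sage.MarginalImputer(reg, X[:512])
reg_values = sage.PermutationEstimator(reg_imputer, 'mse')(X, y)

# Classification: predict_proba returns shape (batch, 2), so 'mse' triggers
# the broadcast error above; 'cross entropy' runs.
clf = XGBClassifier().fit(X, y_cls)
clf_imputer = sage.MarginalImputer(clf, X[:512])
clf_values = sage.PermutationEstimator(clf_imputer, 'cross entropy')(X, y_cls)
```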