oegedijk / explainerdashboard

Quickly build Explainable AI dashboards that show the inner workings of so-called "blackbox" machine learning models.
http://explainerdashboard.readthedocs.io
MIT License

Catboost classifier to work with ClassifierExplainer #59

Closed: yanhong-zhao-ef closed this issue 3 years ago

yanhong-zhao-ef commented 3 years ago

Naively passing a CatBoost classifier to ClassifierExplainer leads to this error message:

Exception: Currently TreeExplainer can only handle models with categorical splits when feature_perturbation="tree_path_dependent" and no background data is passed. Please try again using shap.TreeExplainer(model, feature_perturbation="tree_path_dependent").

After digging into the shap source code (https://github.com/slundberg/shap/blob/master/shap/explainers/_tree.py), it seems that before calculating the SHAP values they do something to correct the conditional sampling, and CatBoost categorical variables are not supported for that yet. So these lines in the explainer might need a bit of tweaking, to something like self._shap_explainer = shap.TreeExplainer(self.model), so that the internals of the SHAP TreeExplainer handle CatBoost out of the box.

A few fixes come to mind:

  1. Expose these inputs so users can set model_output or feature_perturbation.
  2. Have a special CatBoost classifier class.
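Option 1 could look roughly like the sketch below. tree_explainer_kwargs is a hypothetical helper (not part of explainerdashboard or shap) that passes user-chosen model_output / feature_perturbation through to shap.TreeExplainer, but falls back to what the shap error message above asks for when the model has categorical splits: feature_perturbation="tree_path_dependent" and no background data.

```python
from typing import Any, Optional


def tree_explainer_kwargs(
    has_categorical_splits: bool,
    X_background: Optional[Any] = None,
    model_output: Optional[str] = None,
    feature_perturbation: Optional[str] = None,
) -> dict:
    """Hypothetical helper: kwargs for shap.TreeExplainer(model, **kwargs)."""
    kwargs = {}
    if model_output is not None:
        # passed through unchanged; shap may still reject some combinations
        kwargs["model_output"] = model_output
    if has_categorical_splits:
        # shap only supports categorical splits with path-dependent SHAP
        if feature_perturbation not in (None, "tree_path_dependent"):
            raise ValueError(
                "models with categorical splits require "
                'feature_perturbation="tree_path_dependent"'
            )
        kwargs["feature_perturbation"] = "tree_path_dependent"
        return kwargs  # background data deliberately dropped
    if feature_perturbation is not None:
        kwargs["feature_perturbation"] = feature_perturbation
    if X_background is not None:
        kwargs["data"] = X_background  # TreeExplainer's background-data argument
    return kwargs


# A CatBoost model with cat_features: background data is ignored.
print(tree_explainer_kwargs(True, X_background=[[0.0]]))
# A purely numerical model: user settings are passed through.
print(tree_explainer_kwargs(False, X_background=[[0.0]], feature_perturbation="interventional"))
```

The explainer would then call shap.TreeExplainer(self.model, **kwargs), so users get control over the settings without being able to pick a combination shap refuses for categorical splits.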

I will hotfix this issue in my fork and carry on. Maybe I will discover some new caveats that warrant a new class (for tree plotting, perhaps).

oegedijk commented 3 years ago

Hi,

You are right: CatBoost support is currently not very robust, in that I assume X still consists of only numerical data, with categorical features already one-hot encoded. This obviously largely misses the point of using CatBoost in the first place :)

Looking at the shap source code, it seems they get the SHAP values directly from the CatBoost model (the CatBoost team implemented the SHAP algorithm in the model itself, I guess):

elif self.model.model_type == "catboost": # thanks to the CatBoost team for implementing this...
    assert not approximate, "approximate=True is not supported for CatBoost models!"
    assert tree_limit == -1, "tree_limit is not yet supported for CatBoost models!"
    import catboost
    if type(X) != catboost.Pool:
        X = catboost.Pool(X, cat_features=self.model.cat_feature_indices)
    phi = self.model.original_model.get_feature_importance(data=X, fstr_type='ShapValues')

If that gets the SHAP values even for categorical columns, then we can implement the same trick in explainerdashboard.

Currently I calculate self.X_cats, self.shap_values_cats, etc., where the one-hot encoded columns have been merged into a single categorical column and their SHAP values added up. For CatBoost models we would just have to make X_cats the default and set some flag such as cats_only=True.
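The merge described above (one-hot columns back into one categorical column, SHAP values added up) can be sketched with pandas. This is an illustration under assumptions, not explainerdashboard's code: merge_onehot_X and merge_onehot_shap are hypothetical names, and the dummy columns are assumed to be named "<column>_<label>".

```python
import pandas as pd


def merge_onehot_X(X: pd.DataFrame, onehot_cols: dict) -> pd.DataFrame:
    """Collapse each group of one-hot columns back into one categorical column.

    onehot_cols maps the new column name to its one-hot columns, e.g.
    {"Sex": ["Sex_female", "Sex_male"]}; labels are the dummy names with the
    "<name>_" prefix stripped. A row of all zeros would get the first label.
    """
    X_cats = X.copy()
    for col, dummies in onehot_cols.items():
        labels = {d: d[len(col) + 1:] for d in dummies}  # "Sex_male" -> "male"
        # idxmax returns, per row, the name of the dummy column holding the 1
        X_cats[col] = X[dummies].idxmax(axis=1).map(labels)
        X_cats = X_cats.drop(columns=dummies)
    return X_cats


def merge_onehot_shap(shap_df: pd.DataFrame, onehot_cols: dict) -> pd.DataFrame:
    """Add up the SHAP values of each group of one-hot columns."""
    merged = shap_df.copy()
    for col, dummies in onehot_cols.items():
        merged[col] = shap_df[dummies].sum(axis=1)
        merged = merged.drop(columns=dummies)
    return merged


groups = {"Sex": ["Sex_female", "Sex_male"]}
X = pd.DataFrame({"Age": [22, 38], "Sex_female": [0, 1], "Sex_male": [1, 0]})
shap_df = pd.DataFrame(
    {"Age": [0.25, -0.5], "Sex_female": [-0.25, 0.5], "Sex_male": [-0.5, 0.0]}
)
print(merge_onehot_X(X, groups))
print(merge_onehot_shap(shap_df, groups))
```

Because SHAP values are additive, summing within a group keeps each row's total contribution unchanged, which is why the merged values remain a valid explanation.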

It would require a bit of plumbing (lots of if self.cats_only: cats = True in the various methods), but it should be doable.
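One possible alternative to repeating if self.cats_only: cats = True in every method is a small decorator that rewrites the cats argument. respect_cats_only and ToyExplainer are hypothetical names for illustration only; they are not part of explainerdashboard.

```python
import functools
import inspect


def respect_cats_only(method):
    """Force cats=True whenever the instance has cats_only set."""
    sig = inspect.signature(method)

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        bound = sig.bind(self, *args, **kwargs)
        bound.apply_defaults()  # make sure "cats" is present even if not passed
        if getattr(self, "cats_only", False):
            bound.arguments["cats"] = True
        return method(*bound.args, **bound.kwargs)

    return wrapper


class ToyExplainer:
    """Stand-in for an explainer, only to demonstrate the decorator."""

    def __init__(self, cats_only=False):
        self.cats_only = cats_only

    @respect_cats_only
    def columns(self, cats=False):
        return ["Sex"] if cats else ["Sex_female", "Sex_male"]


print(ToyExplainer().columns())
print(ToyExplainer(cats_only=True).columns(cats=False))
```

The trade-off is some indirection: the override is invisible at the call site, so it should be documented on the methods that use it.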

If you could implement something similar to get_xgboost_preds_df for CatBoost, then we could add a marginal tree plot like the xgboost one to the library. And by implementing an equivalent of get_xgboost_path_df we could add the tree path table. For the tree graph we would still depend on dtreeviz support, though.
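For the marginal tree plot, a CatBoost counterpart of get_xgboost_preds_df mainly needs the prediction after each successive tree. The sketch below assumes those cumulative raw predictions for a single observation are already available (for CatBoost, something like model.staged_predict(X_row, prediction_type="RawFormulaVal") may provide them, but treat that as an untested assumption) and turns them into per-tree contributions. The function and column names are hypothetical and not necessarily what explainerdashboard uses.

```python
import pandas as pd


def tree_contributions_df(cumulative_preds, base_value=0.0):
    """Hypothetical sketch: marginal contribution of each tree for one row.

    cumulative_preds[i] is the raw prediction for a single observation using
    only the first i + 1 trees; base_value is the prediction before any tree.
    """
    preds = [base_value] + list(cumulative_preds)
    return pd.DataFrame(
        {
            "tree": range(len(cumulative_preds)),
            "pred": preds[1:],
            # how much adding tree i moved the prediction
            "pred_diff": [after - before for before, after in zip(preds[:-1], preds[1:])],
        }
    )


df = tree_contributions_df([0.5, 0.75, 0.625])
print(df)
```

If the model's raw predictions include a bias term, passing it as base_value keeps it out of the first tree's contribution; otherwise it shows up in tree 0's pred_diff.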

yanhong-zhao-ef commented 3 years ago

@oegedijk thanks for the response.

That sounds like a good solution to me, although it is quite a bit of plumbing. I looked into the shap source code for the first time yesterday, and it does seem like quite a hassle to support models like xgboost, catboost and others :)

Let me know if you have time to work on that.

I will have a look at the xgboost methods and see if I can get something together for catboost!

oegedijk commented 3 years ago

Okay, I got it mostly working, with a few exceptions:

When I detect a CatBoost model with 'object' columns in X, I set explainer.cats_only=True. This makes sure that shap values get calculated correctly. Then in the ExplainerDashboard I set cats=True, hide_cats=True, hide_pdp=True, hide_whatifpdp=True, shap_interaction=False.
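The detection step described above can be sketched in a few lines of pandas. This is a guess at the logic, not the code on the dev branch: is_catboost_model, detect_cats_only and dashboard_defaults are hypothetical names, and checking the class's module name is just one way to recognise a CatBoost model without importing catboost.

```python
import pandas as pd
from pandas.api.types import is_object_dtype


def is_catboost_model(model) -> bool:
    # Look at the defining module rather than importing catboost,
    # so the check also runs when catboost is not installed.
    return type(model).__module__.split(".")[0] == "catboost"


def has_raw_categorical_columns(X: pd.DataFrame) -> bool:
    # 'object' columns, plus pandas' dedicated string dtype
    return any(is_object_dtype(dt) or isinstance(dt, pd.StringDtype) for dt in X.dtypes)


def detect_cats_only(model, X: pd.DataFrame) -> bool:
    return is_catboost_model(model) and has_raw_categorical_columns(X)


def dashboard_defaults(cats_only: bool) -> dict:
    """ExplainerDashboard settings listed in this thread for cats_only explainers."""
    if not cats_only:
        return {}
    return dict(cats=True, hide_cats=True, hide_pdp=True,
                hide_whatifpdp=True, shap_interaction=False)
```

Returning the settings as a dict lets user-supplied keyword arguments be merged on top, so the defaults can still be overridden.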

So it should work without the user having to do anything.

Will have to write some tests, and probably build in some more safeguards here and there. But for now it seems to work.

oegedijk commented 3 years ago

Could you check out the dev branch https://github.com/oegedijk/explainerdashboard/tree/dev and see if it works for you?

yanhong-zhao-ef commented 3 years ago

Cool, will check it out later today and let you know! Thanks @oegedijk

oegedijk commented 3 years ago

Also found a fix for pdpbox, but I will have to get the PR accepted, and the project hasn't been maintained for three years, so that might take a while: https://github.com/oegedijk/PDPbox

yanhong-zhao-ef commented 3 years ago

The fix worked! The X_cats detection works lovely for CatBoost! Thanks @oegedijk!

oegedijk commented 3 years ago

I refactored the PDP code (basically rewrote the relevant bits of the PDPbox library as my own functions), so PDP now works too, and refactored the code some more so that you can also combine categorical and one-hot encoded columns...
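The core of a partial dependence computation is small, which is what makes rewriting it feasible. A generic sketch (not PDPbox's or explainerdashboard's implementation), assuming predict is any function that maps a DataFrame to an array of predictions:

```python
import numpy as np
import pandas as pd


def partial_dependence(predict, X: pd.DataFrame, col: str, grid=None) -> pd.Series:
    """Minimal partial dependence sketch.

    For each grid value, overwrite `col` with that value in every row,
    predict, and average. Values are only assigned, never interpolated,
    so this works for categorical (string) columns as well as numeric ones.
    """
    if grid is None:
        grid = sorted(X[col].unique())
    averages = []
    for value in grid:
        X_mod = X.copy()
        X_mod[col] = value  # every row gets the same value for `col`
        averages.append(float(np.mean(predict(X_mod))))
    return pd.Series(averages, index=list(grid), name=col)


# Toy model: a numeric effect of Age plus a categorical effect of Sex.
def toy_predict(df):
    return df["Age"] * 0.5 + (df["Sex"] == "f") * 1.0


X = pd.DataFrame({"Age": [10, 20], "Sex": ["m", "f"]})
print(partial_dependence(toy_predict, X, "Sex"))
```

Since the categorical column is never one-hot encoded here, the same function covers raw categorical columns (as CatBoost sees them) and numeric ones.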

oegedijk commented 3 years ago

Released it in the new version: https://github.com/oegedijk/explainerdashboard/releases/tag/v0.2.20

It should now work with any model that supports categorical features.

yanhong-zhao-ef commented 3 years ago

Thanks @oegedijk will try out the new release!

Tanay0510 commented 2 years ago

I am building a catboost model using

from catboost import CatBoostClassifier
from explainerdashboard import ClassifierExplainer, ExplainerDashboard

clf = CatBoostClassifier(iterations=50, random_seed=42, learning_rate=0.1)

clf.fit(X_train, y_train, cat_features=cat_features_names, eval_set=(X_val, y_val), verbose=False, plot=True)

Now:

explainer = ClassifierExplainer(clf, X_val, y_val, cv=2)

db = ExplainerDashboard(explainer, title="XYZ", whatif=True, decision_trees=True, cats=True, hide_cats=True, hide_pdp=True, hide_whatifpdp=True, shap_interaction=False)

db.run(port=8051)

Every time I run this, I get the following error:

CatBoostError: features data: pandas.DataFrame column "A" has dtype 'category' but is not in cat_features list.

When I call clf.fit(...), I provide the names of the categorical columns there. I don't know why I am getting this error and I am a bit lost. Any help would be great.
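One thing worth checking for this error is whether every 'category'-dtype column in each DataFrame that reaches CatBoost appears in cat_features. The hypothetical helpers below (not part of catboost or explainerdashboard) assume cat_features holds column names rather than indices; they list the likely categorical columns and flag exactly the mismatch the error message describes.

```python
import pandas as pd
from pandas.api.types import is_object_dtype


def categorical_columns(X: pd.DataFrame) -> list:
    """Columns with 'category', 'object' or string dtype, i.e. likely cat_features."""
    return [
        col
        for col, dt in X.dtypes.items()
        if isinstance(dt, (pd.CategoricalDtype, pd.StringDtype)) or is_object_dtype(dt)
    ]


def undeclared_category_columns(X: pd.DataFrame, cat_features) -> list:
    """'category'-dtype columns missing from cat_features (the case the error reports)."""
    declared = set(cat_features)
    return [
        col
        for col, dt in X.dtypes.items()
        if isinstance(dt, pd.CategoricalDtype) and col not in declared
    ]


X_demo = pd.DataFrame(
    {"A": pd.Categorical(["x", "y"]), "B": ["u", "v"], "C": [1.0, 2.0]}
)
print(undeclared_category_columns(X_demo, ["B"]))
```

Deriving the list from the data, for example cat_features_names = categorical_columns(X_train), keeps cat_features and the DataFrame dtypes in sync, provided X_val uses the same dtypes as X_train.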