shap / shap

A game theoretic approach to explain the output of any machine learning model.
https://shap.readthedocs.io
MIT License

BUG: Native support for (XGBoost) Categoricals #3813

Closed: mattharrison closed this issue 2 weeks ago

mattharrison commented 4 weeks ago

Issue Description

The scatter-plot examples I can find that show categorical features label-encode them and use the display_features argument of shap.dependence_plot to simulate categories.

When I build models with categorical features (with XGBoost or CatBoost), I give those columns the pandas 'category' dtype.

This fails when I try to create a scatter plot to view the impact of a categorical feature.
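Here is a minimal pandas-only sketch of the distinction, using toy values (not the adult data). Label encoding turns each string into an integer that plotting code can subtract. The 'category' dtype instead keeps the string labels and stores integer codes underneath.

```python
import pandas as pd

# Toy column standing in for a string feature such as 'Relationship'
raw = pd.Series(["Husband", "Wife", "Own-child", "Husband", "Unmarried"])

# Label encoding: the feature itself becomes plain integers
codes, labels = pd.factorize(raw)
print(codes.tolist())   # [0, 1, 2, 0, 3]  (order of first appearance)
print(list(labels))     # ['Husband', 'Wife', 'Own-child', 'Unmarried']

# 'category' dtype: the values stay strings, integer codes live underneath
as_cat = raw.astype("category")
print(list(as_cat.cat.categories))  # ['Husband', 'Own-child', 'Unmarried', 'Wife']
print(as_cat.cat.codes.tolist())    # [0, 3, 1, 0, 2]  (indexes into sorted categories)
```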

Minimal Reproducible Example

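import shap
import xgboost

# display=True returns the raw features; string columns such as
# 'Relationship' come back with the pandas 'category' dtype
cal_X, cal_y = shap.datasets.adult(n_points=1000, display=True)

# enable_categorical=True lets XGBoost consume the 'category' columns directly
xg_cal = xgboost.XGBClassifier(enable_categorical=True)
xg_cal.fit(cal_X, cal_y)

ex_cal = shap.TreeExplainer(xg_cal)
vals_cal = ex_cal(cal_X)

# Raises TypeError: the x values for 'Relationship' are string labels
shap.plots.scatter(vals_cal[:, 'Relationship'])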

Traceback

{
    "name": "TypeError",
    "message": "unsupported operand type(s) for -: 'str' and 'str'",
    "stack": "---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[69], line 1
----> 1 shap.plots.scatter(vals_cal[:, 'Relationship'])

File ~/.envs/menv/lib/python3.10/site-packages/shap/plots/_scatter.py:194, in scatter(shap_values, color, hist, axis_color, cmap, dot_size, x_jitter, alpha, title, xmin, xmax, ymin, ymax, overlay, ax, ylabel, show)
    192 min_dist = np.inf
    193 for i in range(1,len(vals)):
--> 194     d = vals[i] - vals[i-1]
    195     if d > 1e-8 and d < min_dist:
    196         min_dist = d

TypeError: unsupported operand type(s) for -: 'str' and 'str'"
}
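The failing frame points to the cause. The scatter code appears to compute the smallest gap between neighbouring distinct x values, presumably to size the automatic horizontal jitter, and subtraction is undefined for string labels. The helper min_spacing below is hypothetical, not a shap function. It mirrors that loop and shows that the same computation succeeds on the integer codes behind a 'category' column.

```python
import numpy as np
import pandas as pd


def min_spacing(values):
    """Smallest positive gap between neighbouring distinct values.

    Hypothetical helper mirroring the loop in the traceback; not shap's code.
    """
    vals = np.sort(np.unique(values))
    min_dist = np.inf
    for i in range(1, len(vals)):
        d = vals[i] - vals[i - 1]  # TypeError here when vals holds strings
        if 1e-8 < d < min_dist:
            min_dist = d
    return min_dist


relationship = pd.Series(["Husband", "Wife", "Own-child", "Husband"], dtype="category")

# The string labels reproduce the error from the report
try:
    min_spacing(np.asarray(relationship))
except TypeError as err:
    print("TypeError:", err)

# The integer category codes are numeric, so the gap computation works
print(min_spacing(relationship.cat.codes.to_numpy()))
```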

Expected Behavior

I would like shap.plots.scatter to produce a plot like the dependence_plot examples on this page: https://shap.readthedocs.io/en/latest/example_notebooks/tabular_examples/tree_based_models/Census%20income%20classification%20with%20XGBoost.html
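The following sketch shows one way to draw such a plot under stated assumptions. It is not a shap API and not necessarily what #3706 does. The hypothetical helper categorical_scatter places each category at an integer x position (its category code), adds uniform horizontal jitter, and labels the ticks with the category names. It needs only numpy, pandas and matplotlib, and the data are synthetic.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def categorical_scatter(feature, shap_values, ax=None, jitter=0.2, seed=0):
    """Scatter SHAP values against a categorical feature (hypothetical helper).

    Each category sits at an integer x position (its category code) with
    uniform horizontal jitter; x ticks are labelled with the category names.
    """
    cat = pd.Series(feature).astype("category")
    codes = cat.cat.codes.to_numpy()
    y = np.asarray(shap_values, dtype=float)
    keep = codes >= 0  # code -1 marks missing values; leave them out
    rng = np.random.default_rng(seed)
    x = codes[keep] + rng.uniform(-jitter, jitter, size=int(keep.sum()))
    if ax is None:
        _, ax = plt.subplots()
    ax.scatter(x, y[keep], s=12, alpha=0.5)
    ax.set_xticks(range(len(cat.cat.categories)))
    ax.set_xticklabels([str(c) for c in cat.cat.categories])
    ax.set_xlabel(cat.name if cat.name is not None else "feature")
    ax.set_ylabel("SHAP value")
    return ax


# Synthetic stand-ins for a categorical column and its SHAP values
rel = pd.Series(["Husband", "Wife", "Own-child", "Husband", "Unmarried"], name="Relationship")
shap_col = np.array([0.8, 0.3, -0.9, 0.7, -0.4])
ax = categorical_scatter(rel, shap_col)
```

Using the category codes as x positions keeps the layout consistent with the category order pandas already stores. The jitter spreads overlapping points within each category.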


Installed Versions

0.46.0

mattharrison commented 4 weeks ago

Posting my workaround so that I (and others) can find it when searching for this in the future.

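import pandas as pd
import seaborn as sns

# Assumes X_reg is the feature DataFrame (with a categorical 'make' column
# and a 'year' column) and vals is the shap Explanation computed on it
makes = ['Ford', 'Toyota', 'Honda', 'Tesla']

(pd.DataFrame(vals.values, columns=X_reg.columns)
     # one '<feature>_shap' column per feature
     .rename(columns=lambda col: f'{col}_shap')
     # attach base values and raw features; reset the index so rows align
     .assign(base_value=vals.base_values, **X_reg.reset_index(drop=True))
     # SHAP value of 'make' per category, colored by 'year'
     .pipe(lambda df_:
        sns.catplot(x='make', y='make_shap', data=df_, alpha=.5,
                    hue='year', palette='RdBu', order=makes))
)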
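The frame-building part of the workaround can be checked without a trained model. In the sketch below, X_reg, shap_values and base_values are synthetic stand-ins (assumptions) for the feature frame, vals.values and vals.base_values. The final groupby gives a text summary of what the catplot shows.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-ins: in the workaround these come from the model and explainer
X_reg = pd.DataFrame({
    "make": pd.Categorical(["Ford", "Tesla", "Honda", "Ford"]),
    "year": [2015, 2021, 2018, 2010],
})
shap_values = np.array([[0.5, -0.2], [1.5, 0.9], [0.1, 0.3], [0.4, -0.8]])  # like vals.values
base_values = np.full(len(X_reg), 10.0)                                     # like vals.base_values

df = (pd.DataFrame(shap_values, columns=X_reg.columns)
        .rename(columns=lambda col: f"{col}_shap")
        .assign(base_value=base_values, **X_reg.reset_index(drop=True)))

print(df.columns.tolist())
# ['make_shap', 'year_shap', 'base_value', 'make', 'year']

# Mean SHAP contribution of 'make' per category (text version of the catplot)
print(df.groupby("make", observed=True)["make_shap"].mean())
```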
hypostulate commented 3 weeks ago

This should be solved by #3706.

thatlittleboy commented 2 weeks ago

Resolved by #3706