shap / shap

A game theoretic approach to explain the output of any machine learning model.
https://shap.readthedocs.io
MIT License

BUG: shap.plots.bar(shap_values) TypeError #3673

Open JeroenWalchenbach opened 5 months ago

JeroenWalchenbach commented 5 months ago

Issue Description

When using the bar plot function from shap, a TypeError is raised at line 260 of shap/plots/_bar.py, in the call to ax.set_yticks.

set_ticks() does not accept a fontsize keyword argument.

Minimal Reproducible Example

import shap
import xgboost as xg
shap.initjs()

# X_train, y_train: the reporter's training data (not defined in this report)
model = xg.XGBRegressor()
model.fit(X_train, y_train)
explainer = shap.Explainer(model)
shap_values = explainer(X_train)

shap.plots.bar(shap_values)

Traceback

Traceback (most recent call last)
/var/folders/pk/yjvh2jtj15ldhmcxb6_ss_ch0000gn/T/ipykernel_11569/768057057.py in <module>
----> 1 shap.plots.bar(shap_values, ax=None)

/opt/anaconda3/lib/python3.9/site-packages/shap/plots/_bar.py in bar(shap_values, max_display, order, clustering, clustering_cutoff, show_data, ax, show)
    258 
    259     # draw the yticks (the 1e-8 is so matplotlib 3.3 doesn't try and collapse the ticks)
--> 260     ax.set_yticks(list(y_pos) + list(y_pos + 1e-8), yticklabels + [t.split('=')[-1] for t in yticklabels], fontsize=13)
    261 
    262     xlen = ax.get_xlim()[1] - ax.get_xlim()[0]

/opt/anaconda3/lib/python3.9/site-packages/matplotlib/axes/_base.py in wrapper(self, *args, **kwargs)
     71 
     72         def wrapper(self, *args, **kwargs):
---> 73             return get_method(self)(*args, **kwargs)
     74 
     75         wrapper.__module__ = owner.__module__

TypeError: set_ticks() got an unexpected keyword argument 'fontsize'
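The last frame shows matplotlib's set_ticks() rejecting fontsize. One plausible cause, stated here as an assumption: the reporter's matplotlib predates 3.5, the release in which set_xticks/set_yticks gained a labels argument and pass-through text keywords such as fontsize. The pure-Python sketch below reproduces the failure without importing matplotlib. old_set_ticks and new_set_ticks are hypothetical stand-ins modeled on the older and newer signatures, and call_like_shap mimics the shape of the call at line 260 of _bar.py (positions, labels, fontsize=13).

```python
import inspect


def old_set_ticks(ticks, *, minor=False):
    """Hypothetical stand-in for an older signature: positions only, no **kwargs."""
    return list(ticks)


def new_set_ticks(ticks, labels=None, *, minor=False, **kwargs):
    """Hypothetical stand-in for a newer signature: labels plus text kwargs."""
    return list(ticks), labels, kwargs


def accepts_labels_and_kwargs(func):
    """True if func has a `labels` parameter and accepts arbitrary keywords."""
    params = inspect.signature(func).parameters
    has_labels = "labels" in params
    has_var_kw = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
    return has_labels and has_var_kw


def call_like_shap(set_ticks):
    """Call set_ticks the way shap's bar plot does; return "ok" or the error text."""
    try:
        set_ticks([0, 1], ["a", "b"], fontsize=13)
        return "ok"
    except TypeError as exc:
        return str(exc)


print(accepts_labels_and_kwargs(old_set_ticks))  # False
print(accepts_labels_and_kwargs(new_set_ticks))  # True
print(call_like_shap(old_set_ticks))  # prints a TypeError message about 'fontsize'
print(call_like_shap(new_set_ticks))  # ok
```

With matplotlib installed, printing matplotlib.__version__ is the quickest way to confirm or rule out this explanation. If it is older than 3.5, upgrading matplotlib would be the first thing to try.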

Expected Behavior

No response


Installed Versions

0.45.1
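The report lists a single version (0.45.1) without naming the package, but the traceback passes through matplotlib's axes/_base.py as well as shap. Below is a minimal sketch for collecting the relevant versions with the standard library's importlib.metadata. The default package list is an assumption chosen for this issue, and installed_versions is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("shap", "matplotlib", "xgboost", "numpy")):
    """Map each distribution name to its installed version, or None if absent."""
    report = {}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            # Package is not installed in the current environment
            report[name] = None
    return report


for name, ver in installed_versions().items():
    print(f"{name}: {ver or 'not installed'}")
```

Pasting this output into the "Installed Versions" field makes a version-dependent failure like this one easier to pin down.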

thatlittleboy commented 3 months ago

Hi @JeroenWalchenbach, your example is not reproducible (X_train and y_train are not defined).

This works, for example:

import shap
import xgboost as xg
shap.initjs()

X_train, y_train = shap.datasets.adult(n_points=500)
model = xg.XGBRegressor()
model.fit(X_train, y_train)
explainer = shap.Explainer(model)
shap_values = explainer(X_train)

shap.plots.bar(shap_values)

So it probably has something to do with how your data is defined.