oegedijk / explainerdashboard

Quickly build Explainable AI dashboards that show the inner workings of so-called "blackbox" machine learning models.
http://explainerdashboard.readthedocs.io
MIT License

shap_values should be 2d, instead shape=(200, 21, 2)! #297

Open harshil17 opened 3 months ago

harshil17 commented 3 months ago

I am running the sample code exactly as given at https://github.com/oegedijk/explainerdashboard, using the Titanic dataset.

I am running into the error "shap_values should be 2d, instead shape=(200, 21, 2)!"

The full error trace is below. Can anyone please help me understand why I am getting this error and how I can resolve it?

`AssertionError                            Traceback (most recent call last)
Cell In[7], line 12
      1 explainer = ClassifierExplainer(model, X_test, y_test,
      2                                 cats=['Deck', 'Embarked',
      3                                     {'Gender': ['Sex_male', 'Sex_female', 'Sex_nan']}],
   (...)
      9                                 target = "Survival", # defaults to y.name
     10                                 )
---> 12 db = ExplainerDashboard(explainer,
     13                         title="Titanic Explainer", # defaults to "Model Explainer"
     14                         shap_interaction=False, # you can switch off tabs with bools
     15                         )
     16 db.run(port=8050)

File I:\Explainer Dashboard\explainer-dashboard\lib\site-packages\explainerdashboard\dashboards.py:803, in ExplainerDashboard.__init__(self, explainer, tabs, title, name, description, simple, hide_header, header_hide_title, header_hide_selector, header_hide_download, hide_poweredby, block_selector_callbacks, pos_label, fluid, mode, width, height, bootstrap, external_stylesheets, server, url_base_pathname, routes_pathname_prefix, requests_pathname_prefix, responsive, logins, port, importances, model_summary, contributions, whatif, shap_dependence, shap_interaction, decision_trees, **kwargs)
    801 if isinstance(tabs, list):
    802     tabs = [self._convert_str_tabs(tab) for tab in tabs]
--> 803     self.explainer_layout = ExplainerTabsLayout(
    804         explainer,
    805         tabs,
    806         title,
    807         description=self.description,
    808         **update_kwargs(
    809             kwargs,
    810             header_hide_title=self.header_hide_title,
    811             header_hide_selector=self.header_hide_selector,
    812             header_hide_download=self.header_hide_download,
    813             hide_poweredby=self.hide_poweredby,
    814             block_selector_callbacks=self.block_selector_callbacks,
    815             pos_label=self.pos_label,
    816             fluid=fluid,
    817         ),
    818     )
    819 else:
    820     tabs = self._convert_str_tabs(tabs)

File I:\Explainer Dashboard\explainer-dashboard\lib\site-packages\explainerdashboard\dashboards.py:119, in ExplainerTabsLayout.__init__(self, explainer, tabs, title, name, description, header_hide_title, header_hide_selector, header_hide_download, hide_poweredby, block_selector_callbacks, pos_label, fluid, **kwargs)
    116 self.fluid = fluid
    118 self.selector = PosLabelSelector(explainer, name="0", pos_label=pos_label)
--> 119 self.tabs = [
    120     instantiate_component(tab, explainer, name=str(i + 1), **kwargs)
    121     for i, tab in enumerate(tabs)
    122 ]
    123 assert (
    124     len(self.tabs) > 0
    125 ), "When passing a list to tabs, need to pass at least one valid tab!"
    127 self.register_components(*self.tabs)

File I:\Explainer Dashboard\explainer-dashboard\lib\site-packages\explainerdashboard\dashboards.py:120, in <listcomp>(.0)
    116 self.fluid = fluid
    118 self.selector = PosLabelSelector(explainer, name="0", pos_label=pos_label)
    119 self.tabs = [
--> 120     instantiate_component(tab, explainer, name=str(i + 1), **kwargs)
    121     for i, tab in enumerate(tabs)
    122 ]
    123 assert (
    124     len(self.tabs) > 0
    125 ), "When passing a list to tabs, need to pass at least one valid tab!"
    127 self.register_components(*self.tabs)

File I:\Explainer Dashboard\explainer-dashboard\lib\site-packages\explainerdashboard\dashboard_methods.py:890, in instantiate_component(component, explainer, name, **kwargs)
    884 kwargs = {
    885     k: v
    886     for k, v in kwargs.items()
    887     if k in init_argspec.args + init_argspec.kwonlyargs
    888 }
    889 if "name" in init_argspec.args + init_argspec.kwonlyargs:
--> 890     component = component(explainer, name=name, **kwargs)
    891 else:
    892     print(
    893         f"ExplainerComponent {component} does not accept a name parameter, "
    894         f"so cannot assign name='{name}': "
   (...)
    899         "cluster will generate its own random uuid name!"
    900     )

File I:\Explainer Dashboard\explainer-dashboard\lib\site-packages\explainerdashboard\dashboard_components\composites.py:545, in IndividualPredictionsComposite.__init__(self, explainer, title, name, hide_predindexselector, hide_predictionsummary, hide_contributiongraph, hide_pdp, hide_contributiontable, hide_title, hide_selector, index_check, **kwargs)
    538     self.summary = RegressionPredictionSummaryComponent(
    539         explainer, hide_selector=hide_selector, **kwargs
    540     )
    542 self.contributions = ShapContributionsGraphComponent(
    543     explainer, hide_selector=hide_selector, **kwargs
    544 )
--> 545 self.pdp = PdpComponent(
    546     explainer, name=self.name + "3", hide_selector=hide_selector, **kwargs
    547 )
    548 self.contributions_list = ShapContributionsTableComponent(
    549     explainer, hide_selector=hide_selector, **kwargs
    550 )
    552 self.index_connector = IndexConnector(
    553     self.index,
    554     [self.summary, self.contributions, self.pdp, self.contributions_list],
    555     explainer=explainer if index_check else None,
    556 )

File I:\Explainer Dashboard\explainer-dashboard\lib\site-packages\explainerdashboard\dashboard_components\overview_components.py:639, in PdpComponent.__init__(self, explainer, title, name, subtitle, hide_col, hide_index, hide_title, hide_subtitle, hide_footer, hide_selector, hide_popout, hide_dropna, hide_sample, hide_gridlines, hide_gridpoints, hide_cats_sort, index_dropdown, feature_input_component, pos_label, col, index, dropna, sample, gridlines, gridpoints, cats_sort, description, **kwargs)
    636 self.index_name = "pdp-index-" + self.name
    638 if self.col is None:
--> 639     self.col = self.explainer.columns_ranked_by_shap()[0]
    641 if self.feature_input_component is not None:
    642     self.exclude_callbacks(self.feature_input_component)

File I:\Explainer Dashboard\explainer-dashboard\lib\site-packages\explainerdashboard\explainers.py:86, in insert_pos_label.<locals>.inner(self, *args, **kwargs)
     84 else:
     85     kwargs.update(dict(pos_label=self.pos_label))
---> 86     return func(self, **kwargs)

File I:\Explainer Dashboard\explainer-dashboard\lib\site-packages\explainerdashboard\explainers.py:1318, in BaseExplainer.columns_ranked_by_shap(self, pos_label)
   1306 @insert_pos_label
   1307 def columns_ranked_by_shap(self, pos_label=None):
   1308     """returns the columns of X, ranked by mean abs shap value
   1309
   1310     Args:
   (...)
   1316
   1317     """
-> 1318     return self.mean_abs_shap_df(pos_label).Feature.tolist()

File I:\Explainer Dashboard\explainer-dashboard\lib\site-packages\explainerdashboard\explainers.py:86, in insert_pos_label.<locals>.inner(self, *args, **kwargs)
     84 else:
     85     kwargs.update(dict(pos_label=self.pos_label))
---> 86     return func(self, **kwargs)

File I:\Explainer Dashboard\explainer-dashboard\lib\site-packages\explainerdashboard\explainers.py:3128, in ClassifierExplainer.mean_abs_shap_df(self, pos_label)
   3126 """mean absolute SHAP values"""
   3127 if not hasattr(self, "_mean_abs_shap_df"):
-> 3128     _ = self.get_shap_values_df()
   3129     self._mean_abs_shap_df = [
   3130         self.get_shap_values_df(pos_label)[self.merged_cols]
   3131         .abs()
   (...)
   3138         for pos_label in self.labels
   3139     ]
   3140 return self._mean_abs_shap_df[pos_label]

File I:\Explainer Dashboard\explainer-dashboard\lib\site-packages\explainerdashboard\explainers.py:86, in insert_pos_label.<locals>.inner(self, *args, **kwargs)
     84 else:
     85     kwargs.update(dict(pos_label=self.pos_label))
---> 86     return func(self, **kwargs)

File I:\Explainer Dashboard\explainer-dashboard\lib\site-packages\explainerdashboard\explainers.py:2845, in ClassifierExplainer.get_shap_values_df(self, pos_label)
   2843 if len(self.labels) == 2:
   2844     if not isinstance(_shap_values, list):
-> 2845         assert (
   2846             len(_shap_values.shape) == 2
   2847         ), f"shap_values should be 2d, instead shape={_shap_values.shape}!"
   2848     elif isinstance(_shap_values, list) and len(_shap_values) == 2:
   2849         # for binary classifier only keep positive class
   2850         _shap_values = _shap_values[1]

AssertionError: shap_values should be 2d, instead shape=(200, 21, 2)!`

Brian-AlphaPlay commented 3 months ago

I have the same error as well, running code that had no changes to it and worked fine before.

EDIT: This looks to be caused by a "breaking change" in the newest version of shap (see https://github.com/shap/shap/releases). The issue does not appear after downgrading to shap==0.44.1.

oegedijk commented 3 months ago

Yes, tests are failing as well. Seems like the output shape of the shap library has changed. Will look into it...
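
For readers hitting the same assertion, here is a minimal sketch of what the shape mismatch means. It assumes (this is not confirmed in the thread) that older shap versions returned a list with one 2D array per class for a binary classifier, while newer versions return a single 3D array shaped `(n_samples, n_features, n_classes)`, which matches the `(200, 21, 2)` in the error. The helper name `positive_class_shap` is hypothetical and is not the fix that was applied in explainerdashboard; it only illustrates how either layout can be reduced to the 2D array the assertion expects.

```python
import numpy as np


def positive_class_shap(shap_values, pos_label=1):
    """Reduce SHAP output to a 2D (n_samples, n_features) array for one class.

    Hypothetical helper for illustration only; not part of shap or
    explainerdashboard.
    """
    if isinstance(shap_values, list):
        # Assumed older layout: a list holding one 2D array per class.
        return np.asarray(shap_values[pos_label])
    values = np.asarray(shap_values)
    if values.ndim == 3:
        # Assumed newer layout: classes stacked on the last axis.
        return values[:, :, pos_label]
    if values.ndim == 2:
        # Already 2D, nothing to do.
        return values
    raise ValueError(f"unexpected shap_values shape: {values.shape}")


# Synthetic data with the same shape as in the error message.
rng = np.random.default_rng(0)
new_style = rng.normal(size=(200, 21, 2))
old_style = [new_style[:, :, 0], new_style[:, :, 1]]

print(positive_class_shap(new_style).shape)  # (200, 21)
print(positive_class_shap(old_style).shape)  # (200, 21)
```

Both layouts end up as the same `(200, 21)` array for the positive class, which is the shape the failing assertion in `get_shap_values_df` checks for.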

oegedijk commented 3 months ago

Okay, I have the fix on master. Will see if I can release tomorrow...

oegedijk commented 3 months ago

Okay, it's released: 0.4.6. I actually had to change the GitHub PyPI release mechanism, so I would appreciate it if you could let me know whether it worked!

harshil17 commented 3 months ago

Thanks. I just tried with 0.4.6 and it's still giving me the same error.

oegedijk commented 3 months ago

`pip install -U explainerdashboard` should install version 0.4.6. Or are you using conda? (It takes about a day for the conda-forge CI to pick up the latest version and release it.)

harshil17 commented 3 months ago

Yes, I am using conda. I see. I will wait until tomorrow then and see if it works.

harshil17 commented 3 months ago

I do see, though, that I am upgraded to 0.4.6.

oegedijk commented 3 months ago

Should also be on conda now. What model are you using? Can you run the classifier example from the README?

harshil17 commented 3 months ago

I just checked with the classifier example from the README using the updated version of explainerdashboard, and it's still giving me the same error.

harshil17 commented 3 months ago

Never mind, I just restarted Jupyter and it worked. Thank you very much.
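
A likely explanation for the restart fixing it is that a running Jupyter kernel keeps using the module it already imported, even after the package is upgraded on disk. The sketch below is one way to spot that situation from inside a notebook. The helper names `installed_vs_loaded` and `needs_restart` are hypothetical, and the check assumes the package exposes a `__version__` attribute and that its import name matches its distribution name; if either assumption fails, the corresponding value comes back as `None`.

```python
import sys
from importlib import metadata


def installed_vs_loaded(package_name):
    """Return (version installed on disk, version already imported here).

    Either value is None when it cannot be determined.
    """
    try:
        installed = metadata.version(package_name)
    except metadata.PackageNotFoundError:
        installed = None
    # Only look at modules this process has already imported; do not import.
    module = sys.modules.get(package_name)
    loaded = getattr(module, "__version__", None) if module is not None else None
    return installed, loaded


def needs_restart(package_name):
    """True when both versions are known and they differ."""
    installed, loaded = installed_vs_loaded(package_name)
    return installed is not None and loaded is not None and installed != loaded


if __name__ == "__main__":
    for name in ("explainerdashboard", "shap"):
        print(name, installed_vs_loaded(name), "restart needed:", needs_restart(name))
```

If `needs_restart` reports `True` for either package, restarting the kernel and re-running the notebook should pick up the upgraded version.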