oegedijk / explainerdashboard

Quickly build Explainable AI dashboards that show the inner workings of so-called "blackbox" machine learning models.
http://explainerdashboard.readthedocs.io
MIT License

Sklearn PLS Regression incompatibility with ExplainerDashboard #258

Open shyam-bayer opened 1 year ago

shyam-bayer commented 1 year ago

I would like to use PLS regression with the explainerdashboard package. However, it throws an error that I can't resolve. It looks like a compatibility issue. Could you please confirm whether PLS regression is compatible?

Below is my script:

from sklearn.cross_decomposition import PLSRegression
from sklearn.datasets import load_diabetes
from explainerdashboard import ExplainerDashboard, RegressionExplainer
import numpy as np
from sklearn import linear_model

diabetes_X, diabetes_y = load_diabetes(as_frame=True, return_X_y=True)
regr = PLSRegression(n_components=2)
regr.fit(diabetes_X, diabetes_y)
explainer = RegressionExplainer(regr, diabetes_X, diabetes_y)
db = ExplainerDashboard(explainer)

I am getting the following error:

Building ExplainerDashboard..
Detected notebook environment, consider setting mode='external', mode='inline' or mode='jupyterlab' to keep the notebook interactive while the dashboard is running...
For this type of model and model_output interactions don't work, so setting shap_interaction=False...
The explainer object has no decision_trees property. so setting decision_trees=False...
Generating layout...
Calculating shap values...
100%|██████████| 442/442 [01:22<00:00, 5.33it/s]

ValueError                                Traceback (most recent call last)
Cell In[38], line 1
----> 1 db = ExplainerDashboard(explainer)

File /usr/local/lib/python3.10/site-packages/explainerdashboard/dashboards.py:795, in ExplainerDashboard.__init__(self, explainer, tabs, title, name, description, simple, hide_header, header_hide_title, header_hide_selector, header_hide_download, hide_poweredby, block_selector_callbacks, pos_label, fluid, mode, width, height, bootstrap, external_stylesheets, server, url_base_pathname, routes_pathname_prefix, requests_pathname_prefix, responsive, logins, port, importances, model_summary, contributions, whatif, shap_dependence, shap_interaction, decision_trees, **kwargs)
    793 if isinstance(tabs, list):
    794     tabs = [self._convert_str_tabs(tab) for tab in tabs]
--> 795     self.explainer_layout = ExplainerTabsLayout(
    796         explainer,
    797         tabs,
    798         title,
    799         description=self.description,
    800         **update_kwargs(
    801             kwargs,
    802             header_hide_title=self.header_hide_title,
    803             header_hide_selector=self.header_hide_selector,
    804             header_hide_download=self.header_hide_download,
    805             hide_poweredby=self.hide_poweredby,
    806             block_selector_callbacks=self.block_selector_callbacks,
    807             pos_label=self.pos_label,
    808             fluid=fluid,
    809         ),
    810     )
    811 else:
    812     tabs = self._convert_str_tabs(tabs)

File /usr/local/lib/python3.10/site-packages/explainerdashboard/dashboards.py:111, in ExplainerTabsLayout.__init__(self, explainer, tabs, title, name, description, header_hide_title, header_hide_selector, header_hide_download, hide_poweredby, block_selector_callbacks, pos_label, fluid, **kwargs)
    108 self.fluid = fluid
    110 self.selector = PosLabelSelector(explainer, name="0", pos_label=pos_label)
--> 111 self.tabs = [
    112     instantiate_component(tab, explainer, name=str(i + 1), **kwargs)
    113     for i, tab in enumerate(tabs)
    114 ]
    115 assert (
    116     len(self.tabs) > 0
    117 ), "When passing a list to tabs, need to pass at least one valid tab!"
    119 self.register_components(*self.tabs)

File /usr/local/lib/python3.10/site-packages/explainerdashboard/dashboards.py:112, in <listcomp>(.0)
    108 self.fluid = fluid
    110 self.selector = PosLabelSelector(explainer, name="0", pos_label=pos_label)
    111 self.tabs = [
--> 112     instantiate_component(tab, explainer, name=str(i + 1), **kwargs)
    113     for i, tab in enumerate(tabs)
    114 ]
    115 assert (
    116     len(self.tabs) > 0
    117 ), "When passing a list to tabs, need to pass at least one valid tab!"
    119 self.register_components(*self.tabs)

File /usr/local/lib/python3.10/site-packages/explainerdashboard/dashboard_methods.py:890, in instantiate_component(component, explainer, name, **kwargs)
    884 kwargs = {
    885     k: v
    886     for k, v in kwargs.items()
    887     if k in init_argspec.args + init_argspec.kwonlyargs
    888 }
    889 if "name" in init_argspec.args + init_argspec.kwonlyargs:
--> 890     component = component(explainer, name=name, **kwargs)
    891 else:
    892     print(
    893         f"ExplainerComponent {component} does not accept a name parameter, "
    894         f"so cannot assign name='{name}': "
   (...)
    899         "cluster will generate its own random uuid name!"
    900     )

File /usr/local/lib/python3.10/site-packages/explainerdashboard/dashboard_components/composites.py:413, in RegressionModelStatsComposite.__init__(self, explainer, title, name, hide_title, hide_modelsummary, hide_predsvsactual, hide_residuals, hide_regvscol, logs, pred_or_actual, residuals, col, **kwargs)
    403 self.preds_vs_actual = PredictedVsActualComponent(
    404     explainer, name=self.name + "0", logs=logs, **kwargs
    405 )
    406 self.residuals = ResidualsComponent(
    407     explainer,
    408     name=self.name + "1",
   (...)
    411     **kwargs,
    412 )
--> 413 self.reg_vs_col = RegressionVsColComponent(
    414     explainer, name=self.name + "2", logs=logs, **kwargs
    415 )

File /usr/local/lib/python3.10/site-packages/explainerdashboard/dashboard_components/regression_components.py:1676, in RegressionVsColComponent.__init__(self, explainer, title, name, subtitle, hide_title, hide_subtitle, hide_footer, hide_col, hide_ratio, hide_points, hide_winsor, hide_cats_topx, hide_cats_sort, hide_popout, col, display, round, points, winsor, cats_topx, cats_sort, plot_sample, description, **kwargs)
   1673 super().__init__(explainer, title, name)
   1675 if self.col is None:
-> 1676     self.col = self.explainer.columns_ranked_by_shap()[0]
   1678 assert self.display in {
   1679     "observed",
   1680     "predicted",
   (...)
   1686     f" but you passed display={self.display}!"
   1687 )
   1689 if self.description is None:

File /usr/local/lib/python3.10/site-packages/explainerdashboard/explainers.py:66, in insert_pos_label.<locals>.inner(self, *args, **kwargs)
     63 @wraps(func)
     64 def inner(self, *args, **kwargs):
     65     if not self.is_classifier:
---> 66         return func(self, *args, **kwargs)
     67     if "pos_label" in kwargs:
     68         if kwargs["pos_label"] is not None:
     69             # ensure that pos_label is int

File /usr/local/lib/python3.10/site-packages/explainerdashboard/explainers.py:1310, in BaseExplainer.columns_ranked_by_shap(self, pos_label)
   1298 @insert_pos_label
   1299 def columns_ranked_by_shap(self, pos_label=None):
   1300     """returns the columns of X, ranked by mean abs shap value
   1301
   1302     Args:
   (...)
   1308
   1309     """
-> 1310     return self.mean_abs_shap_df(pos_label).Feature.tolist()

File /usr/local/lib/python3.10/site-packages/explainerdashboard/explainers.py:66, in insert_pos_label.<locals>.inner(self, *args, **kwargs)
     63 @wraps(func)
     64 def inner(self, *args, **kwargs):
     65     if not self.is_classifier:
---> 66         return func(self, *args, **kwargs)
     67     if "pos_label" in kwargs:
     68         if kwargs["pos_label"] is not None:
     69             # ensure that pos_label is int

File /usr/local/lib/python3.10/site-packages/explainerdashboard/explainers.py:1287, in BaseExplainer.mean_abs_shap_df(self, pos_label)
   1284 """Mean absolute SHAP values per feature."""
   1285 if not hasattr(self, "_mean_abs_shap_df"):
   1286     self._mean_abs_shap_df = (
-> 1287         self.get_shap_values_df(pos_label)[self.merged_cols]
   1288         .abs()
   1289         .mean()
   1290         .sort_values(ascending=False)
   1291         .to_frame()
   1292         .rename_axis(index="Feature")
   1293         .reset_index()
   1294         .rename(columns={0: "MEAN_ABS_SHAP"})
   1295     )
   1296 return self._mean_abs_shap_df

File /usr/local/lib/python3.10/site-packages/explainerdashboard/explainers.py:66, in insert_pos_label.<locals>.inner(self, *args, **kwargs)
     63 @wraps(func)
     64 def inner(self, *args, **kwargs):
     65     if not self.is_classifier:
---> 66         return func(self, *args, **kwargs)
     67     if "pos_label" in kwargs:
     68         if kwargs["pos_label"] is not None:
     69             # ensure that pos_label is int

File /usr/local/lib/python3.10/site-packages/explainerdashboard/explainers.py:1151, in BaseExplainer.get_shap_values_df(self, pos_label)
   1144     self._shap_values_df = pd.DataFrame(
   1145         self.shap_explainer.shap_values(
   1146             torch.tensor(self.X.values), **self.shap_kwargs
   1147         ),
   1148         columns=self.columns,
   1149     )
   1150 else:
-> 1151     self._shap_values_df = pd.DataFrame(
   1152         self.shap_explainer.shap_values(self.X, **self.shap_kwargs),
   1153         columns=self.columns,
   1154     )
   1155 self._shap_values_df = merge_categorical_shap_values(
   1156     self._shap_values_df, self.onehot_dict, self.merged_cols
   1157 ).astype(self.precision)
   1158 return self._shap_values_df

File /usr/local/lib/python3.10/site-packages/pandas/core/frame.py:762, in DataFrame.__init__(self, data, index, columns, dtype, copy)
    754     mgr = arrays_to_mgr(
    755         arrays,
    756         columns,
   (...)
    759         typ=manager,
    760     )
    761 else:
--> 762     mgr = ndarray_to_mgr(
    763         data,
    764         index,
    765         columns,
    766         dtype=dtype,
    767         copy=copy,
    768         typ=manager,
    769     )
    770 else:
    771     mgr = dict_to_mgr(
    772         {},
    773         index,
   (...)
    776         typ=manager,
    777     )

File /usr/local/lib/python3.10/site-packages/pandas/core/internals/construction.py:329, in ndarray_to_mgr(values, index, columns, dtype, copy, typ)
    324     values = values.reshape(-1, 1)
    326 else:
    327     # by definition an array here
    328     # the dtypes will be coerced to a single dtype
--> 329     values = _prep_ndarraylike(values, copy=copy_on_sanitize)
    331 if dtype is not None and not is_dtype_equal(values.dtype, dtype):
    332     # GH#40110 see similar check inside sanitize_array
    333     rcf = not (is_integer_dtype(dtype) and values.dtype.kind == "f")

File /usr/local/lib/python3.10/site-packages/pandas/core/internals/construction.py:583, in _prep_ndarraylike(values, copy)
    581     values = values.reshape((values.shape[0], 1))
    582 elif values.ndim != 2:
--> 583     raise ValueError(f"Must pass 2-d input. shape={values.shape}")
    585 return values

ValueError: Must pass 2-d input. shape=(1, 442, 10)
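
The last frame is pandas refusing to build a DataFrame from a three-dimensional array. A minimal sketch of just that step follows. It assumes (a guess, not something confirmed in this thread) that the SHAP values arrive wrapped in an extra leading axis of length one. The array is random placeholder data and the column names are hypothetical.

```python
import numpy as np
import pandas as pd

# Placeholder standing in for the SHAP output: 1 output x 442 rows x 10 features.
rng = np.random.default_rng(0)
values = rng.normal(size=(1, 442, 10))
columns = [f"feature_{i}" for i in range(10)]  # hypothetical column names

# pandas only accepts 1-d or 2-d arrays here, so this raises ValueError.
try:
    pd.DataFrame(values, columns=columns)
    error_message = None
except ValueError as exc:
    error_message = str(exc)
print(error_message)

# Dropping the length-one leading axis gives the 2-d shape pandas expects.
squeezed = np.squeeze(values, axis=0)
frame = pd.DataFrame(squeezed, columns=columns)
print(frame.shape)
```

This only illustrates the shape mismatch; it does not show where the extra axis comes from inside explainerdashboard or shap.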

List of packages:

Package                    Version
ansi2html                  1.8.0
asttokens                  2.2.1
attrs                      22.2.0
backcall                   0.2.0
certifi                    2022.12.7
charset-normalizer         3.1.0
click                      8.1.3
cloudpickle                2.2.1
colour                     0.1.5
comm                       0.1.3
contourpy                  1.0.7
cycler                     0.11.0
dash                       2.9.2
dash-auth                  2.0.0
dash-bootstrap-components  1.4.1
dash-core-components       2.0.0
dash-html-components       2.0.0
dash-table                 5.0.0
debugpy                    1.6.6
decorator                  5.1.1
dtreeviz                   2.2.0
exceptiongroup             1.1.1
executing                  1.2.0
explainerdashboard         0.4.2.1
Flask                      2.2.3
flask-simplelogin          0.1.1
Flask-WTF                  0.15.1
fonttools                  4.39.3
graphviz                   0.20.1
idna                       3.4
iniconfig                  2.0.0
ipykernel                  6.22.0
ipython                    8.12.0
itsdangerous               2.1.2
jedi                       0.18.2
Jinja2                     3.1.2
joblib                     1.2.0
jupyter_client             8.1.0
jupyter_core               5.3.0
jupyter-dash               0.4.2
kiwisolver                 1.4.4
llvmlite                   0.39.1
MarkupSafe                 2.1.2
matplotlib                 3.7.1
matplotlib-inline          0.1.6
nest-asyncio               1.5.6
numba                      0.56.4
numpy                      1.23.5
oyaml                      1.0
packaging                  23.0
pandas                     1.5.3
parso                      0.8.3
pexpect                    4.8.0
pickleshare                0.7.5
Pillow                     9.4.0
pip                        22.3.1
platformdirs               3.2.0
plotly                     5.14.0
pluggy                     1.0.0
prompt-toolkit             3.0.38
psutil                     5.9.4
ptyprocess                 0.7.0
pure-eval                  0.2.2
Pygments                   2.14.0
pyparsing                  3.0.9
pytest                     7.2.2
python-dateutil            2.8.2
pytz                       2023.3
PyYAML                     6.0
pyzmq                      25.0.2
requests                   2.28.2
retrying                   1.3.4
scikit-learn               1.2.2
scipy                      1.10.1
setuptools                 67.4.0
shap                       0.41.0
six                        1.16.0
slicer                     0.0.7
stack-data                 0.6.2
tenacity                   8.2.2
threadpoolctl              3.1.0
tomli                      2.0.1
tornado                    6.2
tqdm                       4.65.0
traitlets                  5.9.0
urllib3                    1.26.15
waitress                   2.1.2
wcwidth                    0.2.6
Werkzeug                   2.2.3
wheel                      0.38.4
WTForms                    3.0.1

oegedijk commented 1 year ago

Hi @shyam-bayer, my guess is that this is because PLSRegression does not return a single prediction but can return multiple components. This means that model.predict(X_test) returns a two-dimensional numpy array instead of a one-dimensional one, which results in the error.

oegedijk commented 1 year ago

This seems to fix it, at least as long as the PLSRegression model has a single component: https://github.com/oegedijk/explainerdashboard/commit/3f4a9a0e8f1e63dd0f0ef829f1f0fd0ad357aa6d

oegedijk commented 1 year ago

So it should be in the next release.

shyam-bayer commented 1 year ago

Thanks for looking into this. However, like any other regression model, PLSRegression returns a single "ypred" output if there is only one "y" column. PLSRegression can use multiple latent variables, but that does not affect the shape of the "ypred" vector/matrix.