shap / shap

A game theoretic approach to explain the output of any machine learning model.
https://shap.readthedocs.io
MIT License

BUG: Unexpected Interaction Plot Instead of Summary Plot in Multiclass SHAP Summary with XGBoost #3630

Open cconsta1 opened 2 months ago

cconsta1 commented 2 months ago

Issue Description

When attempting to use SHAP with an XGBoost multiclass classification model to generate summary plots, the output unexpectedly appears as an interaction plot rather than the anticipated summary plot. This issue occurs when trying to visualize the SHAP values for all classes simultaneously.

Minimal Reproducible Example

import xgboost as xgb
import shap
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score

# Generate synthetic data
X, y = make_classification(n_samples=500, n_features=20, n_informative=4, n_classes=6, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

# Train an XGBoost model for multiclass classification
model = xgb.XGBClassifier(objective="multi:softprob", random_state=42)
model.fit(X_train, y_train)

# Create a SHAP TreeExplainer
explainer = shap.TreeExplainer(model)

# Calculate SHAP values for the test set
shap_values = explainer.shap_values(X_test)

# Attempt to plot summary for all classes
shap.summary_plot(shap_values, X_test, plot_type="bar")

Expected Behavior

The expected outcome is a summary plot that shows the feature importance for all classes in a clear and aggregated manner.
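The aggregated view asked for here is, roughly, a stacked bar chart of the mean absolute SHAP value of each feature for each class. Below is a minimal NumPy sketch of that computation. It assumes the SHAP values arrive either as a list of per-class (n_samples, n_features) matrices or as one 3-D array shaped (n_samples, n_features, n_classes). The helper name `per_class_mean_abs` is hypothetical, and the input is random data rather than real model output:

```python
import numpy as np


def per_class_mean_abs(shap_values):
    """Mean |SHAP| per feature and class (hypothetical helper).

    Accepts either a list of (n_samples, n_features) arrays, one per class,
    or a single array assumed to be shaped (n_samples, n_features, n_classes).
    Returns an array of shape (n_features, n_classes).
    """
    if isinstance(shap_values, list):
        # Stack the per-class matrices along a new last (class) axis.
        shap_values = np.stack(shap_values, axis=-1)
    shap_values = np.asarray(shap_values)
    if shap_values.ndim != 3:
        raise ValueError("expected (n_samples, n_features, n_classes)")
    # Average the absolute contributions over samples.
    return np.abs(shap_values).mean(axis=0)


# Synthetic SHAP values: 4 samples, 3 features, 2 classes (random numbers).
rng = np.random.default_rng(0)
sv = rng.normal(size=(4, 3, 2))

importance = per_class_mean_abs(sv)  # shape (3, 2): one bar segment per class
total = importance.sum(axis=1)       # total bar length per feature
order = np.argsort(-total)           # most important feature first
for f in order:
    print(f"Feature {f}: total={total[f]:.3f}, per class={np.round(importance[f], 3)}")
```

Summing the per-class segments gives each feature's total bar length, and sorting by that total gives the feature order from most to least important.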

Installed Versions

SHAP version: 0.45.0
Python version: 3.10.12
XGBoost version: 2.0.3
Operating system: Google Colab Pro

wiktorolszowy commented 2 months ago

It is not XGBoost-specific, as I have the same problem with SHAP values derived from CatBoost and LightGBM models. It is related to shap.summary_plot.

mengwang-mw commented 1 month ago

I have encountered the same issue: with multiclass output, the summary_plot function generates an interaction plot instead of the expected summary bar plot.

I fixed this manually by editing the source code to change the type of the TreeExplainer output from a NumPy array back to a list.

Specifically, I opened https://github.com/shap/shap/blob/master/shap/explainers/_tree.py and commented out lines 515-516. After that, the summary plot was generated correctly for multiclass output.

The error stems from a change in version 0.45.0, which switched the output from a list to a NumPy array, as can be seen in lines 410-411 of https://github.com/shap/shap/blob/master/shap/explainers/_tree.py. I reversed this change to fix the issue.
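The same conversion can also be done on the caller's side, without editing `_tree.py`. Below is a minimal sketch. It assumes the 0.45.0 array is shaped (n_samples, n_features, n_classes). My reading is that the legacy `summary_plot` sends a plain 3-D array (as opposed to a list) down the interaction-plot path, which would explain the symptom. `to_per_class_list` is a hypothetical helper, not part of SHAP:

```python
import numpy as np


def to_per_class_list(shap_values):
    """Convert SHAP values to the list-of-arrays layout used before 0.45.0.

    Assumption: a 3-D array is laid out as (n_samples, n_features, n_classes).
    """
    if isinstance(shap_values, list):
        return shap_values  # already one (n_samples, n_features) array per class
    arr = np.asarray(shap_values)
    if arr.ndim != 3:
        return arr  # single-output model: leave the 2-D matrix unchanged
    # Split along the class axis into one 2-D matrix per class.
    return [arr[:, :, k] for k in range(arr.shape[2])]


# Usage with a fitted multiclass model (not executed here):
#   shap_values = explainer.shap_values(X_test)
#   shap.summary_plot(to_per_class_list(shap_values), X_test, plot_type="bar")

demo = np.zeros((5, 4, 3))  # 5 samples, 4 features, 3 classes
parts = to_per_class_list(demo)
print(len(parts), parts[0].shape)  # 3 (5, 4)
```

Because the helper passes lists through unchanged, the same plotting code should keep working under 0.44.1 as well.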

wiktorolszowy commented 1 month ago

Well spotted! I think that in most cases a fast path using the C++ implementation of Tree SHAP is taken, so these two lines (which perform the same data transformation as the lines you pointed to) need to be commented out as well:

https://github.com/shap/shap/blob/86d8bc58a42e9e11901ad506f5c27f55fa4f0349/shap/explainers/_tree.py#L478C1-L479C49

Commenting these lines out most likely has some side effects, but without these lines the SHAP summary plot indeed works for multi-class classification models. Thanks!

Omranic commented 3 weeks ago

I encountered the same problem, and switching back to version 0.44.1 resolved it for me.

Below is a simple script that demonstrates the issue:

import pandas as pd
import shap
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Create a synthetic dataset
X, y = make_classification(n_samples=100, n_features=5, n_informative=3, n_redundant=1, n_clusters_per_class=1, n_classes=3, random_state=42)
features = [f"Feature {i}" for i in range(X.shape[1])]
X = pd.DataFrame(X, columns=features)

# Train a RandomForest model
model = RandomForestClassifier(n_estimators=50, random_state=42)
model.fit(X, y)

# Create the SHAP explainer and compute SHAP values
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)

# Plot SHAP values for each class
shap.summary_plot(shap_values, X, plot_type="bar", class_names=['Class 0', 'Class 1', 'Class 2'])

Here are the screenshots for both versions:

(Two screenshots dated 2024-06-10, one per version, not reproduced here.)

cconsta1 commented 1 week ago

@Omranic, switching back to version 0.44.1 is the solution I went with as well. Thank you all for responding to this issue!