Open bilashda opened 2 years ago
I have implemented a multi-class classifier using GradientBoostingClassifier with 5 labels (0..4). I want a decision plot for a specific observation, in this case the one at index=2 (observation #3). However, the following implementation raises ValueError: The base_values and shap_values args expect lists. Could somebody suggest what corrections are needed and explain the reason?
GBM_explainer = shap.KernelExplainer(finalGBMModel.predict_proba, X_class_train)
GBM_shap_values = GBM_explainer.shap_values(X_class_test)
Change the index value (0-4) to generate the summary plot for a specific class label
shap.summary_plot(GBM_shap_values[4], X_class_test, show=False)
plt.title('GradientBoost SHAP Summary for Class \'4\'')
19.3 Plot the SHAP Decision Plot for a specific data point
The base value is the value each class's prediction starts from
GBM_base_values = GBM_explainer.expected_value  ## the base value is the value where each class starts with
if isinstance(GBM_base_values, list):
    GBM_base_values = GBM_base_values[1]
print(f"Explainer expected value: {GBM_base_values}")
if isinstance(GBM_shap_values, list):
    GBM_shap_values = GBM_shap_values[1]
Function to generate the plot legend labels, including the predicted probability for each class
class_count = len(GBM_base_values)
def class_labels(row_index):
    return [f'Class {i} ({gbmPredictProb[row_index, i].round(2):.2f})' for i in range(class_count)]
Choose the row index in the test data set for which the decision plot should be produced
row_index = 2  ## 3rd data point; a different row index can be used here
Plot the decision plot calling the SHAP Multioutput API
shap.multioutput_decision_plot(GBM_base_values, GBM_shap_values,
                               row_index=row_index,
                               feature_names=X_class_test.columns.tolist(),
                               highlight=[np.argmax(gbmPredictProb[row_index])],  # [np.argmax(gbmPrediction[row_index])]
                               legend_labels=class_labels(row_index),
                               legend_location='lower right', show=False)  # [np.argmax(gbmPredictProb[row_index])]
plt.title("Decision Plot for Data Point :" + str(row_index + 1))
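One likely reading of the error, going by its wording: shap.multioutput_decision_plot wants both base_values and shap_values as Python lists with one entry per class, whereas the code above either passes a NumPy array or narrows the values down to a single class with the [1] indexing. Below is a minimal sketch of a conversion step. The helper name as_class_lists is hypothetical (it is not part of shap), and the layout assumed for a single 3-D SHAP array (classes on the last axis) is an assumption that may not hold for every shap release, so it is worth checking the shapes first.

```python
import numpy as np


def as_class_lists(expected_value, shap_values):
    """Return (base_values, shap_values) as Python lists, one entry per class.

    Hypothetical helper for illustration; it is not part of the shap library.
    """
    # One scalar base value per class, as a plain list of floats.
    base_list = [float(v) for v in np.ravel(expected_value)]
    n_classes = len(base_list)

    if isinstance(shap_values, list):
        # Already a list of (n_samples, n_features) arrays, one per class.
        shap_list = [np.asarray(sv) for sv in shap_values]
    else:
        arr = np.asarray(shap_values)
        # Assumption: a single array has shape (n_samples, n_features, n_classes).
        if arr.ndim != 3 or arr.shape[-1] != n_classes:
            raise ValueError(f"unexpected shap_values shape {arr.shape}")
        shap_list = [arr[:, :, k] for k in range(n_classes)]

    if len(shap_list) != n_classes:
        raise ValueError("number of SHAP arrays does not match number of base values")
    return base_list, shap_list


# Possible usage with the names from the post (not executed here):
# base_values, shap_values = as_class_lists(GBM_explainer.expected_value,
#                                           GBM_explainer.shap_values(X_class_test))
# shap.multioutput_decision_plot(base_values, shap_values, row_index=row_index,
#                                feature_names=X_class_test.columns.tolist())
```

With this approach the two isinstance(..., list) blocks that select element [1] would be dropped, since they discard the other four classes before the multioutput plot is called.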