linkedin / FastTreeSHAP

Fast SHAP value computation for interpreting tree-based models
BSD 2-Clause "Simplified" License

Plotting example #11

Closed jpfeil closed 1 year ago

jpfeil commented 1 year ago

Thanks for making FastTreeSHAP! I'm excited to use it. I want to generate some classic SHAP plots. Could you provide an example, perhaps in a Jupyter notebook, of how to use your plotting library to visualize the SHAP values?

jlyang1990 commented 1 year ago

Thanks for your interest in FastTreeSHAP! The plotting functions in FastTreeSHAP are exactly the same as in SHAP (https://github.com/slundberg/shap), so any visualization tools introduced in SHAP should be directly applicable in FastTreeSHAP. Let me know if there is any mismatch.

junaid1990 commented 1 year ago

Sir, I'd like to ask: how can we make a bar plot that represents the normalized mean absolute SHAP value across all folds for the RF, GB, and XGB models? Or, more generally, how can we compare different models in a single bar, SHAP dependence, or summary plot? I have shared some examples below.

[Three screenshots of example plots]

jlyang1990 commented 1 year ago

I don't think either the FastTreeSHAP or the SHAP package can produce the plots above. The best approach is probably to write your own plotting functions, which also makes them easier to customize and maintain.
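A minimal sketch of such a hand-written plot, assuming you already have one 2-D SHAP array (n_samples x n_features) per cross-validation fold for each model, for example from fasttreeshap.TreeExplainer(model).shap_values(X_fold). For a binary classifier that returns one array per class, pick the positive-class array first. The helper names normalized_mean_abs_shap and plot_model_comparison are hypothetical, and "normalized" is interpreted here as scaling each fold's mean |SHAP| values to sum to 1 before averaging over folds; dividing by the maximum would be an equally reasonable reading.

```python
import numpy as np


def normalized_mean_abs_shap(fold_shap_values):
    """Average the normalized mean |SHAP| per feature across folds.

    fold_shap_values: list of arrays, each of shape (n_samples, n_features),
    holding one fold's SHAP values. Assumed normalization: each fold's
    importances are divided by their sum so they add up to 1.
    """
    per_fold = []
    for sv in fold_shap_values:
        importance = np.abs(np.asarray(sv)).mean(axis=0)  # mean |SHAP| per feature
        per_fold.append(importance / importance.sum())     # scale to sum to 1
    return np.mean(per_fold, axis=0)                       # average over folds


def plot_model_comparison(results, feature_names):
    """Grouped horizontal bar chart with one bar per model for each feature.

    results: dict mapping model name -> list of per-fold SHAP arrays.
    """
    # Imported here so the aggregation helper above has no plotting dependency.
    import matplotlib.pyplot as plt

    n_models = len(results)
    y = np.arange(len(feature_names))
    height = 0.8 / n_models  # split each feature's row among the models
    fig, ax = plt.subplots(figsize=(8, 0.5 * len(feature_names) + 1))
    for i, (name, folds) in enumerate(results.items()):
        ax.barh(y + i * height, normalized_mean_abs_shap(folds),
                height=height, label=name)
    ax.set_yticks(y + height * (n_models - 1) / 2)  # center labels on each group
    ax.set_yticklabels(feature_names)
    ax.set_xlabel("Normalized mean |SHAP value| (averaged over folds)")
    ax.legend()
    fig.tight_layout()
    return fig
```

Usage, with hypothetical per-fold lists rf_folds, gb_folds and xgb_folds: fig = plot_model_comparison({"RF": rf_folds, "GB": gb_folds, "XGB": xgb_folds}, feature_names). Grouping the bars per feature keeps all three models on one axis, which is the side-by-side comparison asked about.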

edmoman commented 1 year ago

I can't create a beeswarm plot as described in the following example if I use fasttreeshap.TreeExplainer instead of shap.Explainer:

https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/beeswarm.html

The beeswarm plotter complains that the output is not a SHAP object.

For instance, this does not work:

```python
import fasttreeshap
import shap

explainer = fasttreeshap.TreeExplainer(classifier, algorithm='auto', n_jobs=-1)
shap_values = explainer(X)  # beeswarm expects an Explanation object here
shap.plots.beeswarm(shap_values)
```

The workaround I have found so far is to use the legacy summary_plot, which seems to accept a plain array of SHAP values instead of an object:

```python
import fasttreeshap
import shap
shap_values = fasttreeshap.TreeExplainer(classifier, algorithm='auto', n_jobs=-1).shap_values(X)
shap.summary_plot(shap_values, X)
```

That works.

jlyang1990 commented 1 year ago

@edmoman The commit https://github.com/linkedin/FastTreeSHAP/commit/b08edb76b1a06c4b2a2ad3c9042fedc2ee373473 should fix the beeswarm plotting issue. Let me know if this issue still exists.