bgreenwell / fastshap

Fast approximate Shapley values in R
https://bgreenwell.github.io/fastshap/

Example of Shap "summary" plot? #17

Closed jcpsantiago closed 2 years ago

jcpsantiago commented 3 years ago

I really like the summary plot created by the shap Python package. Do you have an example snippet for creating it from fastshap::explain() output, e.g. for an XGBoost model?

For reference, this is the kind of plot I mean (from https://github.com/slundberg/shap).

jcpsantiago commented 3 years ago

This might be what I'm looking for: the TODO comment # 1. Add \code{type = "beeswarm"}. in https://github.com/bgreenwell/fastshap/blob/c96733e478d36c1c0f7c890307e42e11a8f1540c/R/autoplot.R#L3

bgreenwell commented 3 years ago

@jcpsantiago I’ll make sure this gets added in to the next update.

jcpsantiago commented 3 years ago

Here's a snippet from the code I use to create a similar chart. It might be useful.

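library(dplyr)
library(tidyr)
library(ggplot2)

# `shap`: data frame of SHAP values from fastshap::explain(), one column per feature
# `baked_training`: the matching feature values plus the `is_fraud` outcome

# Top 10 features by mean absolute SHAP value
top_10_feat <- shap %>%
  pivot_longer(everything()) %>%
  group_by(name) %>%
  summarise(v = mean(abs(value))) %>%
  arrange(desc(v)) %>%
  head(10) %>%
  pull(name)

# Long format: one row per (observation, feature) with its SHAP value and feature value
df <- shap %>%
  rename_with(~ paste0(.x, "_shap")) %>%
  pivot_longer(everything(), names_to = "shap_keys", values_to = "shap_values") %>%
  bind_cols(baked_training %>%
              select(is_fraud, all_of(names(shap))) %>%
              pivot_longer(-is_fraud)
            ) %>%
  select(name, value, shap_values, is_fraud) %>%
  filter(!is.na(value) & !is.na(shap_values)) %>%
  group_by(name) %>%
  # Min-max scale each feature to [0, 1] for the colour gradient
  # (dividing by max(value) alone breaks for negative, e.g. normalized, features)
  mutate(scaled = (value - min(value)) / (max(value) - min(value))) %>%
  ungroup() %>%
  # Order features by their largest absolute SHAP value
  mutate(name = forcats::fct_reorder(
    as.factor(name),
    .x = shap_values, .fun = function(x) max(abs(x))
  ))

# Sina (beeswarm-like) plot of the top 10 features
df %>%
  filter(name %in% top_10_feat) %>%
  ggplot(aes(x = name, y = shap_values, color = scaled)) +
  ggforce::geom_sina(alpha = 0.3) +
  coord_flip() +
  scale_colour_viridis_c() +
  hrbrthemes::theme_ipsum_ps() +
  guides(color = guide_colourbar(
    title = "Scaled feature value",
    barwidth = 20, barheight = 0.5, title.position = "top"
  )) +
  theme(legend.position = "bottom") +
  labs(title = "SHAP values for top 10 features", x = "", y = "")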
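For anyone who wants to avoid the tidyverse dependencies, the two preparation steps behind this chart (ranking features by mean absolute SHAP value and scaling each feature's values to [0, 1] for the colour scale) can be sketched in base R. This is a minimal sketch, not part of fastshap: the helper names rank_features(), scale_features() and beeswarm_data() are hypothetical, and it assumes the SHAP values and feature values share the same rows and column names.

```r
# Minimal base-R sketch (no tidyverse) of the data preparation for a SHAP
# beeswarm plot. Helper names are hypothetical; `shap` and `X` are assumed to
# have the same rows and the same column names.

# Feature names ordered by mean absolute SHAP value (most important first)
rank_features <- function(shap, n = ncol(shap)) {
  imp <- colMeans(abs(data.matrix(shap)))
  names(sort(imp, decreasing = TRUE))[seq_len(min(n, length(imp)))]
}

# Min-max scale each column to [0, 1]; constant columns map to 0.5
scale_features <- function(X) {
  apply(data.matrix(X), 2, function(x) {
    rng <- range(x, na.rm = TRUE)
    if (rng[1] == rng[2]) rep(0.5, length(x)) else (x - rng[1]) / (rng[2] - rng[1])
  })
}

# Long format for plotting: one row per (observation, feature) pair,
# restricted to the n most important features
beeswarm_data <- function(shap, X, n = 10) {
  shap <- data.matrix(shap)
  top <- rank_features(shap, n)
  data.frame(
    # Reverse levels so the most important feature ends up on top after coord_flip()
    feature = factor(rep(top, each = nrow(shap)), levels = rev(top)),
    shap_value = as.vector(shap[, top, drop = FALSE]),
    scaled_value = as.vector(scale_features(data.matrix(X)[, top, drop = FALSE]))
  )
}

# Example with toy values
shap_toy <- matrix(c(1, -2, 3, 0.5, -0.5, 0.1), nrow = 3,
                   dimnames = list(NULL, c("a", "b")))
X_toy <- matrix(c(10, 20, 30, 5, 5, 5), nrow = 3,
                dimnames = list(NULL, c("a", "b")))
beeswarm_data(shap_toy, X_toy, n = 2)
```

The returned data frame has one row per observation and feature, so it can be plotted the same way as df in the snippet above (e.g. with ggforce::geom_sina() and coord_flip()).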
jotech commented 3 years ago

Is there an update on this? It would be really great to have this kind of plot in fastshap :)

jcpsantiago commented 3 years ago

@jotech I've been using the snippet I shared ☝️ in our model cards for the weekly deployments. It's not a single function, but it works :)

jonesworks commented 2 years ago

One option is to use reticulate to call the Python shap plotting functions directly. Judging from the issues on the shap repository, matplotlib 3.2.2 seems to be required.

Minimal example, taken from the docs:

# Load required packages
library(fastshap)
library(xgboost)
# Load the Boston housing data (requires the pdp package)
# install.packages("pdp")

data(boston, package = "pdp")
X <- data.matrix(subset(boston, select = -cmedv))  # matrix of feature values

# Fit a gradient boosted regression tree ensemble; hyperparameters were tuned 
# using `autoxgb::autoxgb()`
set.seed(859)  # for reproducibility
bst <- xgboost(X, label = boston$cmedv, nrounds = 338, max_depth = 3, eta = 0.1,
               verbose = 0)

# Compute exact explanations for all rows
ex <- explain(bst, exact = TRUE, X = X)

Next, use reticulate to call shap's plotting functions:

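library(reticulate)
shap <- import("shap")

# Pass the column names explicitly: the R matrix arrives in Python as a plain
# NumPy array, so shap would otherwise label features "Feature 0", "Feature 1", ...
shap$dependence_plot("rank(1)", data.matrix(ex), X, feature_names = colnames(X))
shap$summary_plot(data.matrix(ex), X, feature_names = colnames(X))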

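Naming the feature directly threw an error for me, so the "rank(i)" form was necessary. It selects features by mean absolute SHAP value, counting from 0, so rank(1) is the second most important feature; rank(2) and rank(3) also work. Rendering the plot repeatedly produced buggy visualizations (at least in my experience).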
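To make the "rank(i)" form concrete, here is a small R sketch of which feature it picks, assuming shap's behaviour of ordering features by mean absolute SHAP value over all rows, largest first, with i counted from 0. The helper shap_rank_feature() is hypothetical and only mirrors that lookup; it does not call Python.

```r
# Hypothetical helper mirroring shap's "rank(i)" lookup: features are ordered
# by mean absolute SHAP value (largest first) and i counts from 0.
shap_rank_feature <- function(shap_values, i) {
  shap_values <- data.matrix(shap_values)
  ord <- order(colMeans(abs(shap_values)), decreasing = TRUE)
  idx <- ord[i + 1]
  # Return the feature name when columns are named, otherwise the column index
  if (is.null(colnames(shap_values))) idx else colnames(shap_values)[idx]
}

# e.g. shap_rank_feature(ex, 1) names the feature that "rank(1)" plots
toy <- matrix(c(0.1, -0.2, 3, -1, 0.5, 0.4), nrow = 2,
              dimnames = list(NULL, c("x1", "x2", "x3")))
shap_rank_feature(toy, 0)  # "x2"
```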

kapsner commented 2 years ago

I've implemented a ggplot2-based beeswarm plot for fastshap's autoplot. You can try it out with the following code:

remotes::install_github("kapsner/fastshap", ref = "feat_beeswarm_plot")

library(doParallel)
library(fastshap)
library(ggplot2)
library(ranger)

boston <- pdp::boston
boston$chas <- as.integer(boston$chas) - 1
X <- data.matrix(subset(boston, select = -cmedv))

# Train a random forest
set.seed(944)  # for reproducibility
rfo <- ranger(cmedv ~ ., data = boston)

# Prediction wrapper: must return a numeric vector of predictions
pfun <- function(object, newdata) {
  predict(object, data = newdata)$predictions
}

# Compute approximate Shapley values
set.seed(945)
system.time(
  shap <- fastshap::explain(rfo, X = X, nsim = 50, pred_wrapper = pfun)
)

p <- autoplot(object = shap, type = "beeswarm", X = boston)
p

[Figure: beeswarm_plot]

bgreenwell commented 2 years ago

This is really awesome @kapsner, I’ll take a look!

kapsner commented 2 years ago

FYI, the just-released R package shapviz also provides a ggplot2-based beeswarm plot and other R-native visualizations for Shapley values.

bgreenwell commented 2 years ago

@kapsner, does this support fastshap too? I like this idea and have considered moving all the plotting functions, which have a lot more dependencies, into a new package away from the core functionality.

kapsner commented 2 years ago

Yes, its documentation also shows how to create plots from fastshap objects.
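As a sketch of that route, the snippet below builds toy SHAP values and draws shapviz's beeswarm plot. It assumes that shapviz::shapviz() accepts a SHAP matrix plus the matching feature values via X, and that sv_importance(kind = "beeswarm") draws the beeswarm; check the shapviz documentation, which also covers fastshap explain() objects directly. The toy data are made up for illustration; with fastshap you would use data.matrix(ex) and the X passed to explain().

```r
# Toy data (hypothetical): replace with data.matrix(ex) from fastshap::explain()
# and the feature matrix X passed to it
set.seed(1)
X_toy <- matrix(rnorm(200), ncol = 4, dimnames = list(NULL, paste0("x", 1:4)))
shap_toy <- sweep(X_toy, 2, c(2, 1, 0.5, 0.1), `*`)  # fake SHAP values per column

# Only run the plotting part if shapviz is installed
if (requireNamespace("shapviz", quietly = TRUE)) {
  sv <- shapviz::shapviz(shap_toy, X = X_toy)          # SHAP matrix + feature values
  p <- shapviz::sv_importance(sv, kind = "beeswarm")   # ggplot2 beeswarm plot
  print(p)
}
```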

bgreenwell commented 2 years ago

Closing as this is supported in the shapviz project.