bgreenwell / fastshap

Fast approximate Shapley values in R
https://bgreenwell.github.io/fastshap/
115 stars, 18 forks

Shap values don't add up to prediction values #6

Closed: prescient closed this issue 4 years ago

prescient commented 4 years ago

I'm not sure if I'm using this correctly, but I would have thought that the sum of the SHAP values should equal the prediction. In this case, they don't seem to:

require(fastshap)
require(ranger)
require(tidyverse)

data(iris)

mdl <- ranger(Sepal.Length ~ . -Species,
              data = iris)
ranger_predict <- function(object, newdata){
  return(predict(object, newdata)$predictions)
}

rng_predictions <- predict(mdl, iris)$predictions

shaps <- explain(object = mdl, 
                 X = iris %>% select(-Species, - Sepal.Length), 
                 pred_wrapper = ranger_predict,
                 nsim = 1)
shap_predictions <- rowSums(shaps)
diffs <- rng_predictions - shap_predictions
plot(diffs)
bgreenwell commented 4 years ago

Hi @prescient, sorry for the delay. The exact Shapley values for an instance should add up to the difference between the prediction for that instance and the average prediction across the entire training set. This is not the case, however, for approximate Shapley values, which are the default in fastshap. The sum should, though, converge to that difference as nsim goes to infinity, which is briefly demonstrated here: https://bgreenwell.github.io/fastshap/articles/fastshap.html
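To make the additivity property concrete, here is a minimal base-R sketch. It is not fastshap's code. It assumes a hypothetical toy prediction function `f`, simulated background data `X`, and an interventional value function `v(S)`: features in `S` are fixed at the explained instance and the rest are taken from each background row. The helpers `v`, `exact_shapley`, and `mc_shapley` are illustrative names. Exact values from enumerating every coalition sum to `f(x) - mean(f(X))` up to floating-point error. A Monte Carlo permutation estimate only does so on average.

```r
# Hypothetical toy model and data (assumptions, not from the thread)
f <- function(X) 2 * X[, 1] - X[, 2] + 0.5 * X[, 1] * X[, 3]
set.seed(101)
X <- matrix(rnorm(200 * 3), ncol = 3)  # background ("training") data
x <- X[1, , drop = FALSE]               # instance to explain
p <- ncol(X)

# Value of coalition S: average prediction with the features in S fixed at x
# and the remaining features taken from each background row
v <- function(S) {
  Z <- X
  if (length(S) > 0) {
    Z[, S] <- matrix(x[, S], nrow(X), length(S), byrow = TRUE)
  }
  mean(f(Z))
}

# Exact Shapley values by enumerating every coalition of the other features
exact_shapley <- function() {
  phi <- numeric(p)
  for (j in seq_len(p)) {
    others <- setdiff(seq_len(p), j)
    for (k in 0:length(others)) {
      w <- factorial(k) * factorial(p - k - 1) / factorial(p)  # Shapley weight
      subsets <- if (k == 0) {
        list(integer(0))
      } else {
        lapply(combn(length(others), k, simplify = FALSE), function(i) others[i])
      }
      for (S in subsets) phi[j] <- phi[j] + w * (v(c(S, j)) - v(S))
    }
  }
  phi
}

# Monte Carlo permutation estimate: random background row + random ordering
mc_shapley <- function(nsim) {
  phi <- numeric(p)
  for (j in seq_len(p)) {
    est <- numeric(nsim)
    for (m in seq_len(nsim)) {
      w <- X[sample(nrow(X), 1), , drop = FALSE]
      o <- sample(p)
      before <- o[seq_len(which(o == j) - 1)]   # features ordered before j
      b1 <- w
      b2 <- w
      b1[, c(before, j)] <- x[, c(before, j)]   # feature j taken from x
      if (length(before) > 0) b2[, before] <- x[, before]  # j kept from w
      est[m] <- f(b1) - f(b2)
    }
    phi[j] <- mean(est)
  }
  phi
}

target <- f(x) - mean(f(X))   # prediction minus average prediction
phi_exact <- exact_shapley()
print(c(sum_exact = sum(phi_exact), target = target))

set.seed(2)
for (nsim in c(1L, 10L, 100L, 1000L)) {
  gap <- sum(mc_shapley(nsim)) - target
  cat(sprintf("nsim = %4d   sum(phi) - target = %8.4f\n", nsim, gap))
}
```

Each Monte Carlo estimate is unbiased here, so the printed gap is noise around zero whose typical size falls roughly like 1/sqrt(nsim); a single run is not guaranteed to shrink monotonically.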

I've thought about adding an option to scale the explanations, but I'm not completely sure what the downside would be.

Does that help?

bgreenwell commented 4 years ago

Linking to this issue: https://github.com/slundberg/shap/issues/998

bgreenwell commented 4 years ago

Just pushed an update (currently experimental). Now you can try setting adjust = TRUE as follows (note, however, that this requires nsim > 1):

require(ranger)
require(tidyverse)

data(iris)

mdl <- ranger(Sepal.Length ~ . -Species,
              data = iris)
ranger_predict <- function(object, newdata){
  return(predict(object, newdata)$predictions)
}

rng_predictions <- predict(mdl, iris)$predictions

shaps <- fastshap::explain(
  object = mdl, 
  X = iris %>% select(-Species, - Sepal.Length), 
  pred_wrapper = ranger_predict,
  nsim = 10,  # needs to be > 1
  adjust = TRUE
)
diffs <- rng_predictions - mean(rng_predictions)
plot(rowSums(shaps), diffs); abline(0, 1)

[Image: the resulting plot of rowSums(shaps) against diffs]
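The internals of `adjust = TRUE` are not described in this thread, so the following is only one plausible way to force Monte Carlo estimates to satisfy local accuracy, written in base R with a hypothetical toy model `f`, simulated data `X`, and hypothetical helpers `mc_draws` and `adjust_shapley`. It spreads the residual `target - sum(phi)` across features in proportion to each estimate's sampling variance, which is the solution to minimizing `sum((phi_adj - phi)^2 / phi_var)` subject to `sum(phi_adj) == target`. Because those variances are estimated from the nsim draws per feature, a correction of this kind cannot work with nsim = 1, which is consistent with the nsim > 1 requirement mentioned above.

```r
# Hypothetical toy model and data (assumptions, not from the thread)
f <- function(X) 2 * X[, 1] - X[, 2] + 0.5 * X[, 1] * X[, 3]
set.seed(101)
X <- matrix(rnorm(200 * 3), ncol = 3)  # background data
x <- X[1, , drop = FALSE]               # instance to explain
p <- ncol(X)

# Hypothetical helper: nsim Monte Carlo draws of feature j's contribution
# (random background row plus random feature ordering)
mc_draws <- function(j, nsim) {
  vapply(seq_len(nsim), function(m) {
    w <- X[sample(nrow(X), 1), , drop = FALSE]
    o <- sample(p)
    before <- o[seq_len(which(o == j) - 1)]   # features ordered before j
    b1 <- w
    b2 <- w
    b1[, c(before, j)] <- x[, c(before, j)]   # feature j taken from x
    if (length(before) > 0) b2[, before] <- x[, before]  # j kept from w
    f(b1) - f(b2)
  }, numeric(1))
}

# Hypothetical helper: shift phi so it sums to target, giving the largest
# shifts to the estimates with the largest Monte Carlo variance
adjust_shapley <- function(phi, phi_var, target) {
  if (anyNA(phi_var)) stop("variance estimates require nsim > 1")
  resid <- target - sum(phi)
  if (all(phi_var == 0)) return(phi + resid / length(phi))  # equal split
  phi + resid * phi_var / sum(phi_var)
}

nsim <- 10
draws <- lapply(seq_len(p), mc_draws, nsim = nsim)
phi <- vapply(draws, mean, numeric(1))
phi_var <- vapply(draws, var, numeric(1)) / nsim  # variance of each mean
target <- f(x) - mean(f(X))                       # local accuracy target
phi_adj <- adjust_shapley(phi, phi_var, target)
print(c(raw = sum(phi), adjusted = sum(phi_adj), target = target))
```

With variance weighting, features whose estimates are already precise absorb little of the correction, while noisy ones absorb most of it.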

prescient commented 4 years ago

Hey Ben. I appreciate the response. I didn't know if I was just using the package wrong or not. Turns out this was the expected behavior. I'll test out the adjust option. I'm often working with stakeholders in the business, and things have to add up =]