bgreenwell / fastshap

Fast approximate Shapley values in R
https://bgreenwell.github.io/fastshap/

random survival forests #40

Closed: mao223 closed this issue 1 year ago

mao223 commented 2 years ago

Good day. Is it possible to get SHAP values for random survival forest models? For instance, those generated from randomForestSRC? Thank you very much.

bgreenwell commented 2 years ago

Hi @mao223, yes it is. The way to go would be via the pred_wrapper argument, which tells fastshap::explain() how to obtain predictions from your fitted model. This is especially necessary for a random survival forest because you can use the model to predict several different quantities. The Titanic example here shows how to use this argument. If you post a small reproducible example using that package, I'd be happy to help with a more direct example.

mao223 commented 2 years ago

Thank you so much for your response. Here is an example of a random survival forest implementation using randomForestSRC:

library(randomForestSRC)
data(peakVO2, package = "randomForestSRC")
dta <- peakVO2
obj <- rfsrc(Surv(ttodead,died)~., dta,
             ntree = 1000, nodesize = 5, nsplit = 50, importance = TRUE)
print(obj)

                         Sample size: 2231
                    Number of deaths: 726
                     Number of trees: 1000
           Forest terminal node size: 5
       Average no. of terminal nodes: 259.368
No. of variables tried at each split: 7
              Total no. of variables: 39
       Resampling used to grow trees: swor
    Resample size used to grow trees: 1410
                            Analysis: RSF
                              Family: surv
                      Splitting rule: logrank *random*
       Number of random split points: 50
                          (OOB) CRPS: 0.15476339
   (OOB) Requested performance error: 0.29830498

Based on your suggestion, I implemented the following

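library(fastshap)

# Data frame containing just the features (drop the survival time and event indicator)
X <- subset(dta, select = -c(died, ttodead))

# Prediction wrapper: for a survival forest, $predicted is the ensemble mortality
pfun <- function(object, newdata) {
  predict(object, newdata = newdata)$predicted
}

# Approximate Shapley values for each individual (100 Monte Carlo repetitions)
set.seed(101)  # make the Monte Carlo estimates reproducible
ex <- fastshap::explain(obj, X = X, pred_wrapper = pfun, nsim = 100,
                        newdata = X)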

Is this the correct thing to do? Thank you very much for your guidance.

bgreenwell commented 1 year ago

Sorry for the delay @mao223, but yes, as long as those are the predictions of interest!
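
To illustrate what "predictions of interest" can mean here: the wrapper above explains the ensemble mortality, but a survival forest can also be explained on another scale, such as the predicted survival probability at a fixed time. The following is only a sketch, not a method confirmed in this thread. The names surv_at, make_surv_wrapper, and t0 are hypothetical, and it assumes the prediction object carries time.interest (sorted event times) and survival (one row per observation, one column per event time), as randomForestSRC survival predictions do.

```r
# Hypothetical helper: survival probability at time t0 from a prediction
# object with `time.interest` (sorted event times) and `survival`
# (n x length(time.interest) matrix of ensemble survival curves).
surv_at <- function(pred, t0) {
  # Index of the last event time at or before t0 (0 if t0 precedes all of them)
  idx <- findInterval(t0, pred$time.interest)
  if (idx == 0) {
    # Before the first event time, the step function has not dropped yet
    return(rep(1, nrow(pred$survival)))
  }
  pred$survival[, idx]
}

# Hypothetical factory: builds a two-argument function(object, newdata)
# suitable for the pred_wrapper argument, with t0 fixed in the closure.
make_surv_wrapper <- function(t0) {
  function(object, newdata) {
    surv_at(predict(object, newdata = newdata), t0)
  }
}

# Possible usage with the objects from this thread (not run here; t0 = 5 is
# an arbitrary illustrative time on the scale of ttodead):
# ex_surv <- fastshap::explain(obj, X = X, pred_wrapper = make_surv_wrapper(5),
#                              nsim = 100, newdata = X)
```

With a wrapper like this, the Shapley values would be contributions to the survival probability at t0 rather than to mortality, so the sign of an effect will generally flip relative to the mortality-based explanation.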