ModelOriented / kernelshap

Different SHAP algorithms
https://modeloriented.github.io/kernelshap/
GNU General Public License v2.0

kernelshap crashes with too many trees in random forest #131

Closed: shanepon closed this issue 6 months ago

shanepon commented 6 months ago

I am trying to use kernelshap to compute SHAP values for a random forest survival model built with parsnip/censored and the partykit engine, but I am running into a situation where R aborts the session, depending mostly on how large the forest is (in terms of trees) and how large bg_X is. For example, the R session terminates when running kernelshap if the trees argument in rand_forest() is too large (e.g., trees = 1000). I tried the same process with the aorsf engine instead of partykit and did not encounter the issue, even with 10000 trees. So I suspect the issue arises from the size/structure of the partykit model and how kernelshap handles it, but I am not sure of the specifics. I can't provide a reprex of the crash itself because the R session dies.

[Screenshot: reprex crash]

When it crashes, my code editor reports the following exit code (-1073741819), but I don't know what it means: “The terminal process "C:\Users\pgaut\AppData\Local\Programs\Python\Python312\Scripts\radian.exe '--no-save', '--no-restore', '--r-binary=C:\Program Files\R\R-4.3.3\bin\x64\R.exe'" terminated with exit code: -1073741819.”
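For reference, a negative Windows exit code like this is a 32-bit status value printed as a signed integer. The base-R sketch below converts it to the usual hexadecimal form; the helper `to_hex()` is hypothetical and written only for this illustration. The result, 0xC0000005, is the Windows status `STATUS_ACCESS_VIOLATION` (the process tried to access memory it was not allowed to), which is consistent with a crash in native code rather than an ordinary R error that R could report before exiting.

```r
# Hypothetical helper: convert a non-negative whole number to an upper-case
# hexadecimal string. sprintf("%X") expects values in R's integer range,
# so the conversion is done digit by digit on a double instead.
to_hex <- function(x) {
  if (x == 0) return("0x0")
  digits <- c(0:9, LETTERS[1:6])  # "0".."9", "A".."F"
  out <- character(0)
  while (x > 0) {
    out <- c(digits[x %% 16 + 1], out)
    x <- x %/% 16
  }
  paste0("0x", paste(out, collapse = ""))
}

exit_code <- -1073741819
# Reinterpret the signed 32-bit value as unsigned by adding 2^32.
unsigned_code <- exit_code + 2^32  # 3221225477
to_hex(unsigned_code)              # "0xC0000005" (STATUS_ACCESS_VIOLATION)
```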

I’ve tried this in RStudio as well as in a plain R console and get the same error.

I have included a reprex of the process with fewer trees (trees = 50), which does run to completion. On my system, increasing trees to 1000 leads to the crash.

Are there any suggestions on how to overcome this issue?

# HOUSEKEEPING ####
rm(list = ls(all = TRUE)) # clean house
# CRAN libraries
library(tidyverse) # install.packages("tidyverse")
library(tidymodels) # install.packages("tidymodels") # pak::pak("tidymodels")
library(censored) # install.packages("censored") # pak::pak("tidymodels/censored") # pak::pak("tidymodels/parsnip")  # pak::pak("tidymodels/tune")
#> Loading required package: survival
library(kernelshap) # pak::pak("ModelOriented/kernelshap")
library(reprex) # pak::pak("reprex")
# Custom operators, functions, and datasets
unregister_dopar <- function() {
    env <- foreach:::.foreachGlobals
    rm(list = ls(name = env), pos = env)
}

# GENERATE DATA ####
set.seed(42)
n_samples <- 303
n_features <- 33

data <- matrix(runif(n_samples * n_features, min = 0.1, max = 5), nrow = n_samples, ncol = n_features) %>%
    data.frame() %>%
    tibble() %>%
    mutate(
        Status = sample(c(0, 1), n_samples, replace = TRUE, prob = c(0.9, 0.1)),
        Time = sample(c(1:1095), n_samples, replace = TRUE),
        surv = Surv(Time, Status),
        .before = 1
    ) %>%
    add_rowindex() %>%
    relocate(.row)

# INITIAL SPLIT ####
set.seed(42)
df_resampling <- data %>% initial_split(strata = Status)
training <- df_resampling %>% training()
testing <- df_resampling %>% testing()

# BUILD MODELS ####
mod_rsf <- rand_forest(
    mtry = 5,
    trees = 50,
    min_n = 10
) %>%
    set_engine("partykit") %>%
    set_mode("censored regression")

# DEFINE RECIPE ####
recipe_surv <- recipes::recipe(surv ~ ., data = data) %>%
    recipes::update_role(c(.row, Status, Time), new_role = "ID")

# DEFINE WORKFLOW ####
workflow_surv <- workflows::workflow() %>%
    workflows::add_recipe(recipe_surv) %>%
    workflows::add_model(mod_rsf)

# FIT rand_forest ####
fit <- workflow_surv %>% parsnip::fit(data = training)
fit_parsnip <- fit %>% extract_fit_parsnip()

# DEFINE CUSTOM PREDICT FUNCTION FOR THE PARSNIP SURVIVAL FUNCTION ####
custom_pred_rsf_parsnip <- function(model, newdata, times) {
    parsnip::predict_survival(
        object = model,
        new_data = newdata,
        type = "survival",
        eval_time = times
    ) %>%
        tidyr::unnest(.pred) %>%
        dplyr::filter(.eval_time %in% times) %>%
        dplyr::select(.pred_survival)
}

# DEFINE FEATURES ####
features <- data %>%
    dplyr::select(-.row, -surv, -Status, -Time) %>%
    colnames()

# TEST KERNELSHAP ####
cl <- parallel::makeCluster(18, type = "PSOCK")
doParallel::registerDoParallel(cl)

set.seed(42)
res_tmp <- kernelshap::kernelshap(
    object = fit_parsnip,
    X = testing,
    bg_X = training %>% slice_sample(n = 50),
    pred_fun = custom_pred_rsf_parsnip,
    times = c(1095),
    feature_names = features
    # parallel = TRUE,
    # parallel_args = list(.packages = c("tidyverse", "tidymodels", "censored"))
)
#> Kernel SHAP values by the hybrid strategy of degree 1
#>   |======================================================================| 100%

# FREE WORKERS AND RESOLVE ENVIRONMENT ####
parallel::stopCluster(cl)
unregister_dopar()

Created on 2024-04-18 with reprex v2.0.2

Standard output and standard error

``` sh
-- nothing to show --
```

Session info

``` r
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.3.3 (2024-02-29 ucrt)
#>  os       Windows 11 x64 (build 22631)
#>  system   x86_64, mingw32
#>  ui       RTerm
#>  language (EN)
#>  collate  English_Canada.utf8
#>  ctype    English_Canada.utf8
#>  tz       America/Edmonton
#>  date     2024-04-18
#>  pandoc   3.1.1 @ C:/Program Files/RStudio/resources/app/bin/quarto/bin/tools/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version    date (UTC) lib source
#>  backports      1.4.1      2021-12-13 [1] CRAN (R 4.3.0)
#>  broom        * 1.0.5      2023-06-09 [1] CRAN (R 4.3.1)
#>  censored     * 0.3.0      2024-01-31 [1] CRAN (R 4.3.3)
#>  class          7.3-22     2023-05-03 [2] CRAN (R 4.3.3)
#>  cli            3.6.1      2023-03-23 [1] CRAN (R 4.3.1)
#>  codetools      0.2-19     2023-02-01 [2] CRAN (R 4.3.3)
#>  coin           1.4-2      2021-10-08 [1] CRAN (R 4.3.1)
#>  colorspace     2.1-0      2023-01-23 [1] CRAN (R 4.3.1)
#>  data.table     1.14.8     2023-02-17 [1] CRAN (R 4.3.1)
#>  dials        * 1.2.0      2023-04-03 [1] CRAN (R 4.3.2)
#>  DiceDesign     1.10       2023-12-07 [1] CRAN (R 4.3.2)
#>  digest         0.6.33     2023-07-07 [1] CRAN (R 4.3.1)
#>  doParallel     1.0.17     2022-02-07 [1] CRAN (R 4.3.1)
#>  dplyr        * 1.1.2      2023-04-20 [1] CRAN (R 4.3.1)
#>  ellipsis       0.3.2      2021-04-29 [1] CRAN (R 4.3.1)
#>  evaluate       0.21       2023-05-05 [1] CRAN (R 4.3.1)
#>  fansi          1.0.4      2023-01-22 [1] CRAN (R 4.3.1)
#>  fastmap        1.1.1      2023-02-24 [1] CRAN (R 4.3.1)
#>  forcats      * 1.0.0      2023-01-29 [1] CRAN (R 4.3.1)
#>  foreach        1.5.2      2022-02-02 [1] CRAN (R 4.3.1)
#>  Formula        1.2-5      2023-02-24 [1] CRAN (R 4.3.0)
#>  fs             1.6.3      2023-07-20 [1] CRAN (R 4.3.1)
#>  furrr          0.3.1      2022-08-15 [1] CRAN (R 4.3.3)
#>  future         1.33.0     2023-07-01 [1] CRAN (R 4.3.1)
#>  future.apply   1.11.0     2023-05-21 [1] CRAN (R 4.3.1)
#>  generics       0.1.3      2022-07-05 [1] CRAN (R 4.3.1)
#>  ggplot2      * 3.4.4      2023-10-12 [1] CRAN (R 4.3.2)
#>  globals        0.16.2     2022-11-21 [1] CRAN (R 4.3.0)
#>  glue           1.6.2      2022-02-24 [1] CRAN (R 4.3.1)
#>  gower          1.0.1      2022-12-22 [1] CRAN (R 4.3.0)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.3.2)
#>  gtable         0.3.4      2023-08-21 [1] CRAN (R 4.3.1)
#>  hardhat        1.3.0      2023-03-30 [1] CRAN (R 4.3.1)
#>  hms            1.1.3      2023-03-21 [1] CRAN (R 4.3.1)
#>  htmltools      0.5.6      2023-08-10 [1] CRAN (R 4.3.1)
#>  infer        * 1.0.5      2023-09-06 [1] CRAN (R 4.3.2)
#>  inum           1.0-5      2023-03-09 [1] CRAN (R 4.3.1)
#>  ipred          0.9-14     2023-03-09 [1] CRAN (R 4.3.1)
#>  iterators      1.0.14     2022-02-05 [1] CRAN (R 4.3.1)
#>  kernelshap   * 0.4.2      2024-04-16 [1] Github (ModelOriented/kernelshap@559ba5c)
#>  knitr          1.43       2023-05-25 [1] CRAN (R 4.3.1)
#>  lattice        0.22-5     2023-10-24 [2] CRAN (R 4.3.3)
#>  lava           1.7.2.1    2023-02-27 [1] CRAN (R 4.3.1)
#>  lhs            1.1.6      2022-12-17 [1] CRAN (R 4.3.2)
#>  libcoin        1.0-9      2021-09-27 [1] CRAN (R 4.3.1)
#>  lifecycle      1.0.4      2023-11-07 [1] CRAN (R 4.3.2)
#>  listenv        0.9.0      2022-12-16 [1] CRAN (R 4.3.1)
#>  lubridate    * 1.9.2      2023-02-10 [1] CRAN (R 4.3.1)
#>  magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.3.1)
#>  MASS           7.3-60.0.1 2024-01-13 [2] CRAN (R 4.3.3)
#>  Matrix         1.6-1      2023-08-14 [1] CRAN (R 4.3.1)
#>  matrixStats    1.0.0      2023-06-02 [1] CRAN (R 4.3.1)
#>  modeldata    * 1.2.0      2023-08-09 [1] CRAN (R 4.3.2)
#>  modeltools     0.2-23     2020-03-05 [1] CRAN (R 4.3.0)
#>  multcomp       1.4-25     2023-06-20 [1] CRAN (R 4.3.1)
#>  munsell        0.5.0      2018-06-12 [1] CRAN (R 4.3.1)
#>  mvtnorm        1.2-3      2023-08-25 [1] CRAN (R 4.3.1)
#>  nnet           7.3-19     2023-05-03 [2] CRAN (R 4.3.3)
#>  parallelly     1.36.0     2023-05-26 [1] CRAN (R 4.3.0)
#>  parsnip      * 1.2.1      2024-03-22 [1] CRAN (R 4.3.3)
#>  partykit       1.2-20     2023-04-14 [1] CRAN (R 4.3.2)
#>  pillar         1.9.0      2023-03-22 [1] CRAN (R 4.3.1)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.3.1)
#>  prodlim        2023.03.31 2023-04-02 [1] CRAN (R 4.3.1)
#>  purrr        * 1.0.2      2023-08-10 [1] CRAN (R 4.3.1)
#>  R.cache        0.16.0     2022-07-21 [1] CRAN (R 4.3.1)
#>  R.methodsS3    1.8.2      2022-06-13 [1] CRAN (R 4.3.0)
#>  R.oo           1.25.0     2022-06-12 [1] CRAN (R 4.3.0)
#>  R.utils        2.12.2     2022-11-11 [1] CRAN (R 4.3.1)
#>  R6             2.5.1      2021-08-19 [1] CRAN (R 4.3.1)
#>  Rcpp           1.0.11     2023-07-06 [1] CRAN (R 4.3.1)
#>  readr        * 2.1.4      2023-02-10 [1] CRAN (R 4.3.1)
#>  recipes      * 1.0.10     2024-02-18 [1] CRAN (R 4.3.3)
#>  reprex       * 2.0.2      2022-08-17 [1] CRAN (R 4.3.1)
#>  rlang          1.1.1      2023-04-28 [1] CRAN (R 4.3.1)
#>  rmarkdown      2.24       2023-08-14 [1] CRAN (R 4.3.1)
#>  rpart          4.1.23     2023-12-05 [2] CRAN (R 4.3.3)
#>  rsample      * 1.2.0      2023-08-23 [1] CRAN (R 4.3.2)
#>  rstudioapi     0.15.0     2023-07-07 [1] CRAN (R 4.3.1)
#>  sandwich       3.0-2      2022-06-15 [1] CRAN (R 4.3.1)
#>  scales       * 1.3.0      2023-11-28 [1] CRAN (R 4.3.2)
#>  sessioninfo    1.2.2      2021-12-06 [1] CRAN (R 4.3.2)
#>  stringi        1.7.12     2023-01-11 [1] CRAN (R 4.3.0)
#>  stringr      * 1.5.0      2022-12-02 [1] CRAN (R 4.3.1)
#>  styler         1.10.1     2023-06-05 [1] CRAN (R 4.3.1)
#>  survival     * 3.5-8      2024-02-14 [2] CRAN (R 4.3.3)
#>  TH.data        1.1-2      2023-04-17 [1] CRAN (R 4.3.1)
#>  tibble       * 3.2.1      2023-03-20 [1] CRAN (R 4.3.1)
#>  tidymodels   * 1.1.1      2023-08-24 [1] CRAN (R 4.3.2)
#>  tidyr        * 1.3.0      2023-01-24 [1] CRAN (R 4.3.1)
#>  tidyselect     1.2.0      2022-10-10 [1] CRAN (R 4.3.1)
#>  tidyverse    * 2.0.0      2023-02-22 [1] CRAN (R 4.3.3)
#>  timechange     0.2.0      2023-01-11 [1] CRAN (R 4.3.1)
#>  timeDate       4022.108   2023-01-07 [1] CRAN (R 4.3.0)
#>  tune         * 1.2.0      2024-03-20 [1] CRAN (R 4.3.3)
#>  tzdb           0.4.0      2023-05-12 [1] CRAN (R 4.3.1)
#>  utf8           1.2.3      2023-01-31 [1] CRAN (R 4.3.1)
#>  vctrs          0.6.3      2023-06-14 [1] CRAN (R 4.3.1)
#>  withr          2.5.2      2023-10-30 [1] CRAN (R 4.3.2)
#>  workflows    * 1.1.4      2024-02-19 [1] CRAN (R 4.3.3)
#>  workflowsets * 1.0.1      2023-04-06 [1] CRAN (R 4.3.2)
#>  xfun           0.40       2023-08-09 [1] CRAN (R 4.3.1)
#>  yaml           2.3.7      2023-01-23 [1] CRAN (R 4.3.0)
#>  yardstick    * 1.3.1      2024-03-21 [1] CRAN (R 4.3.3)
#>  zoo            1.8-12     2023-04-13 [1] CRAN (R 4.3.1)
#>
#>  [1] C:/Users/Shane/AppData/Local/R/win-library/4.3
#>  [2] C:/Program Files/R/R-4.3.3/library
#>
#> ──────────────────────────────────────────────────────────────────────────────
```
mayer79 commented 6 months ago

Thanks for digging into this. Random forests and SHAP are a bit of a weak spot...

Detail: I'd replace dplyr::select(.pred_survival) with dplyr::pull(.pred_survival).
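The difference is in the return type: `select()` returns a one-column tibble, while `pull()` returns a plain numeric vector, which is the simplest form for a prediction function that has to deliver one prediction per row of input. Below is a sketch of the reprex's prediction function with only that change applied; it assumes a single evaluation time in `times` (as in the reprex), so there is exactly one survival probability per row. The toy tibble at the end uses made-up values purely to show the two return types.

```r
library(dplyr)

# Same prediction function as in the reprex, but returning a numeric vector.
# Assumes a single evaluation time in `times`, giving one value per row of `newdata`.
custom_pred_rsf_parsnip <- function(model, newdata, times) {
  parsnip::predict_survival(
    object = model,
    new_data = newdata,
    type = "survival",
    eval_time = times
  ) %>%
    tidyr::unnest(.pred) %>%
    dplyr::filter(.eval_time %in% times) %>%
    dplyr::pull(.pred_survival)
}

# Toy tibble (made-up values) with the same columns as the unnested predictions.
toy <- tibble::tibble(.eval_time = c(1095, 1095), .pred_survival = c(0.8, 0.6))
select(toy, .pred_survival)  # a 2 x 1 tibble
pull(toy, .pred_survival)    # a numeric vector: 0.8 0.6
```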

Generally, I feel that the prediction function is extremely slow, even for as few as 50 trees as in your example.

In non-parallel mode, kernelshap() consumes 5 GB of memory, and most of the time is spent predicting values and transforming the predictions.
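A rough illustration of why prediction dominates, assuming the standard Kernel SHAP approach in which bg_X is used to integrate out "switched off" features: for one observation x and one on-off vector z, features that are "on" are taken from x and the rest from each background row, and the model is evaluated on all of those rows. Each on-off vector therefore costs nrow(bg_X) predictions. The base-R sketch below uses a hypothetical helper `mask_rows()` and toy data; it is not kernelshap's internal code.

```r
# Hypothetical helper: data needed to evaluate one on-off vector z for one row x.
# Columns with z == TRUE are copied from x; the others keep the background values.
mask_rows <- function(x, bg_X, z) {
  out <- bg_X
  out[, z] <- x[rep(1, nrow(bg_X)), z, drop = FALSE]
  out
}

set.seed(1)
p <- 5
bg_X <- as.data.frame(matrix(runif(50 * p), ncol = p))  # 50 toy background rows
x <- as.data.frame(matrix(runif(p), ncol = p))           # one toy row to explain
z <- c(TRUE, FALSE, TRUE, FALSE, FALSE)                  # features 1 and 3 "on"

stacked <- mask_rows(x, bg_X, z)
nrow(stacked)  # 50: one prediction per background row for this single z

# The contribution of z would be the mean prediction over `stacked`.
# Illustrative count: all vectors with exactly one feature on or one feature off.
n_vectors <- 2 * p
n_vectors * nrow(bg_X)  # 500 rows to predict in this toy setting
```

With the reprex's 33 features and 50 background rows, the same illustrative count gives 66 * 50 = 3,300 rows to predict per explained observation, so both a slow prediction function and a large bg_X multiply quickly.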