ModelOriented / survex

Explainable Machine Learning in Survival Analysis
https://modeloriented.github.io/survex
GNU General Public License v3.0

Explainer for graph learners from mlr3proba #94

Closed. Lee-xli closed this issue 1 month ago

Lee-xli commented 1 month ago

Hi survex! I'm wondering whether there is any advice on building explainers from an mlr3 graph learner. This may be a silly question, since graph learners are heavily used for preprocessing (as in the xgboost example below), not to mention their other possible use cases.

The mlr3proba example with KM composition seems to suggest that graph learners should work.

Have I missed anything obvious?

Many thanks!

library(survex)
library(survival)
library(mlr3proba)
#> Loading required package: mlr3
library(mlr3extralearners)
library(mlr3pipelines)
library(mlr3tuning)
#> Loading required package: paradox
veteran <- survival::veteran
veteran_task <- as_task_surv(veteran,
                             time = "time",
                             event = "status",
                             type = "right")
xgb_basic = as_learner(
  po("encode") %>>%
    auto_tuner(tuner=tnr("grid_search", resolution = 2, batch_size =5),
               lrn("surv.xgboost",
                   eta = to_tune(0.001, 0.1)),
               rsmp("cv", folds = 5),
               measure = msr("surv.cindex"),
               terminator= trm("stagnation", iters=50)),
  store_models = TRUE)
#> Warning: 'surv.xgboost' will be deprecated in the future. Use 'surv.xgboost.cox'
#> or 'surv.xgboost.aft' learners instead.
xgb_basic$id = 'xgb'

xgb_basic$train(veteran_task)
#> INFO  [23:36:21.293] [bbotk] Starting to optimize 1 parameter(s) with '<OptimizerBatchGridSearch>' and '<TerminatorStagnation> [iters=50, threshold=0]'
#> INFO  [23:36:21.350] [bbotk] Evaluating 2 configuration(s)
#> INFO  [23:36:21.369] [mlr3] Running benchmark with 10 resampling iterations
#> INFO  [23:36:21.423] [mlr3] Applying learner 'surv.xgboost' on task 'veteran' (iter 1/5)
#> INFO  [23:36:21.460] [mlr3] Applying learner 'surv.xgboost' on task 'veteran' (iter 2/5)
#> INFO  [23:36:21.491] [mlr3] Applying learner 'surv.xgboost' on task 'veteran' (iter 3/5)
#> INFO  [23:36:21.522] [mlr3] Applying learner 'surv.xgboost' on task 'veteran' (iter 4/5)
#> INFO  [23:36:21.555] [mlr3] Applying learner 'surv.xgboost' on task 'veteran' (iter 5/5)
#> INFO  [23:36:21.590] [mlr3] Applying learner 'surv.xgboost' on task 'veteran' (iter 1/5)
#> INFO  [23:36:21.631] [mlr3] Applying learner 'surv.xgboost' on task 'veteran' (iter 2/5)
#> INFO  [23:36:21.666] [mlr3] Applying learner 'surv.xgboost' on task 'veteran' (iter 3/5)
#> INFO  [23:36:21.697] [mlr3] Applying learner 'surv.xgboost' on task 'veteran' (iter 4/5)
#> INFO  [23:36:21.729] [mlr3] Applying learner 'surv.xgboost' on task 'veteran' (iter 5/5)
#> INFO  [23:36:21.764] [mlr3] Finished benchmark
#> INFO  [23:36:21.821] [bbotk] Result of batch 1:
#> INFO  [23:36:21.825] [bbotk]    eta surv.cindex warnings errors runtime_learners
#> INFO  [23:36:21.825] [bbotk]  0.001   0.6511272        0      0            0.105
#> INFO  [23:36:21.825] [bbotk]  0.100   0.6511272        0      0            0.111
#> INFO  [23:36:21.825] [bbotk]                                 uhash
#> INFO  [23:36:21.825] [bbotk]  d56aa0d3-4df0-4c7c-a14f-79d99281fba1
#> INFO  [23:36:21.825] [bbotk]  063591f6-b839-4558-a97b-9bd4a34e97fd
#> INFO  [23:36:21.843] [bbotk] Finished optimizing after 2 evaluation(s)
#> INFO  [23:36:21.844] [bbotk] Result:
#> INFO  [23:36:21.847] [bbotk]    eta learner_param_vals  x_domain surv.cindex
#> INFO  [23:36:21.847] [bbotk]  <num>             <list>    <list>       <num>
#> INFO  [23:36:21.847] [bbotk]  0.001          <list[5]> <list[1]>   0.6511272

xgb_explainer <- explain(xgb_basic,
                         data = veteran[, -c(3,4)],
                         y = Surv(veteran$time, veteran$status),
                         label = "xgb model",
                         type = 'survival')
#> Preparation of a new explainer is initiated
#>   -> model label       :  xgb model 
#>   -> data              :  137  rows  6  cols 
#>   -> target variable   :  137  values 
#>   -> predict function  :  yhat.default will be used (  default  )
#>   -> predicted values  :  No value for predict function target column. (  default  )
#>   -> model_info        :  package Model of class: GraphLearner package unrecognized , ver. Unknown , task regression (  default  ) 
#>   -> model_info        :  type set to  survival 
#>   -> predicted values  :  numerical, min =  -0.6938823 , mean =  -0.6925251 , max =  -0.6902816  
#>   -> residual function :  difference between y and yhat (  default  )
#>   -> residuals         :  the residual_function returns an error when executed (  WARNING  ) 
#>   A new explainer has been created!
exp_xgb <- model_survshap(xgb_explainer, veteran[c(1:4, 17:20, 110:113, 126:129), -c(3,4)], aggregation_method='mean_absolute')
#> Error in UseMethod("model_survshap", explainer): no applicable method for 'model_survshap' applied to an object of class "explainer"

Created on 2024-08-11 by the reprex package (v2.0.1)

Session info

``` r
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.1.1 (2021-08-10)
#>  os       macOS Big Sur 10.16
#>  system   x86_64, darwin17.0
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Australia/Adelaide
#>  date     2024-08-11
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package           * version    date       lib source
#>  backports           1.5.0      2024-05-23 [1] CRAN (R 4.1.1)
#>  bbotk               1.0.0      2024-06-28 [1] CRAN (R 4.1.1)
#>  checkmate           2.3.1      2023-12-04 [1] CRAN (R 4.1.1)
#>  cli                 3.6.3      2024-06-21 [1] CRAN (R 4.1.1)
#>  codetools           0.2-18     2020-11-04 [2] CRAN (R 4.1.1)
#>  colorspace          2.1-0      2023-01-23 [1] CRAN (R 4.1.2)
#>  crayon              1.4.1      2021-02-08 [2] CRAN (R 4.1.0)
#>  DALEX               2.4.3      2023-01-15 [1] CRAN (R 4.1.2)
#>  data.table          1.15.4     2024-03-30 [1] CRAN (R 4.1.1)
#>  dictionar6          0.1.3      2021-09-13 [1] CRAN (R 4.1.0)
#>  digest              0.6.36     2024-06-23 [1] CRAN (R 4.1.1)
#>  distr6              1.8.4      2024-06-13 [1] Github (xoopR/distr6@95d7359)
#>  dplyr               1.1.3      2023-09-03 [1] CRAN (R 4.1.1)
#>  evaluate            0.24.0     2024-06-10 [1] CRAN (R 4.1.1)
#>  fansi               1.0.6      2023-12-08 [1] CRAN (R 4.1.1)
#>  fastmap             1.1.0      2021-01-25 [2] CRAN (R 4.1.0)
#>  fs                  1.5.0      2020-07-31 [2] CRAN (R 4.1.0)
#>  future              1.33.2     2024-03-26 [1] CRAN (R 4.1.1)
#>  future.apply        1.11.2     2024-03-28 [1] CRAN (R 4.1.1)
#>  generics            0.1.3      2022-07-05 [1] CRAN (R 4.1.2)
#>  ggplot2             3.5.1      2024-04-23 [1] CRAN (R 4.1.1)
#>  globals             0.16.3     2024-03-08 [1] CRAN (R 4.1.1)
#>  glue                1.7.0      2024-01-09 [1] CRAN (R 4.1.1)
#>  gtable              0.3.5      2024-04-22 [1] CRAN (R 4.1.1)
#>  highr               0.9        2021-04-16 [2] CRAN (R 4.1.0)
#>  htmltools           0.5.6      2023-08-10 [1] CRAN (R 4.1.1)
#>  jsonlite            1.7.2      2020-12-09 [2] CRAN (R 4.1.0)
#>  knitr               1.33       2021-04-24 [2] CRAN (R 4.1.0)
#>  lattice             0.20-44    2021-05-02 [2] CRAN (R 4.1.1)
#>  lgr                 0.4.4      2022-09-05 [1] CRAN (R 4.1.2)
#>  lifecycle           1.0.4      2023-11-07 [1] CRAN (R 4.1.1)
#>  listenv             0.9.1      2024-01-29 [1] CRAN (R 4.1.1)
#>  magrittr            2.0.3      2022-03-30 [1] CRAN (R 4.1.2)
#>  Matrix              1.3-4      2021-06-01 [2] CRAN (R 4.1.1)
#>  mlr3              * 0.20.0     2024-06-28 [1] CRAN (R 4.1.1)
#>  mlr3extralearners * 0.8.0-9000 2024-06-15 [1] Github (mlr-org/mlr3extralearners@6dc6965)
#>  mlr3misc            0.15.1     2024-06-24 [1] CRAN (R 4.1.1)
#>  mlr3pipelines     * 0.6.0      2024-07-16 [1] Github (mlr-org/mlr3pipelines@c542a26)
#>  mlr3proba         * 0.6.3      2024-06-13 [1] Github (mlr-org/mlr3proba@5205752)
#>  mlr3tuning        * 1.0.0      2024-06-29 [1] CRAN (R 4.1.1)
#>  mlr3viz             0.9.0      2024-07-01 [1] CRAN (R 4.1.1)
#>  munsell             0.5.1      2024-04-01 [1] CRAN (R 4.1.1)
#>  ooplah              0.2.0      2022-01-21 [1] CRAN (R 4.1.2)
#>  palmerpenguins      0.1.1      2022-08-15 [1] CRAN (R 4.1.2)
#>  paradox           * 1.0.1      2024-07-09 [1] CRAN (R 4.1.1)
#>  parallelly          1.37.1     2024-02-29 [1] CRAN (R 4.1.1)
#>  param6              0.2.4      2023-11-22 [1] Github (xoopR/param6@0fa3577)
#>  patchwork           1.2.0      2024-01-08 [1] CRAN (R 4.1.1)
#>  pillar              1.9.0      2023-03-22 [1] CRAN (R 4.1.2)
#>  pkgconfig           2.0.3      2019-09-22 [2] CRAN (R 4.1.0)
#>  R6                  2.5.1      2021-08-19 [1] CRAN (R 4.1.0)
#>  Rcpp                1.0.12     2024-01-09 [1] CRAN (R 4.1.1)
#>  reprex              2.0.1      2021-08-05 [1] CRAN (R 4.1.0)
#>  RhpcBLASctl         0.23-42    2023-02-11 [1] CRAN (R 4.1.2)
#>  rlang               1.1.4      2024-06-04 [1] CRAN (R 4.1.1)
#>  rmarkdown           2.10       2021-08-06 [2] CRAN (R 4.1.0)
#>  rstudioapi          0.15.0     2023-07-07 [1] CRAN (R 4.1.1)
#>  scales              1.3.0      2023-11-28 [1] CRAN (R 4.1.1)
#>  sessioninfo         1.1.1      2018-11-05 [2] CRAN (R 4.1.0)
#>  set6                0.2.6      2023-11-22 [1] Github (xoopR/set6@a901255)
#>  stringi             1.7.3      2021-07-16 [2] CRAN (R 4.1.0)
#>  stringr             1.5.0      2022-12-02 [1] CRAN (R 4.1.2)
#>  survex            * 1.2.0      2023-10-24 [1] CRAN (R 4.1.1)
#>  survival          * 3.7-0      2024-06-05 [1] CRAN (R 4.1.1)
#>  tibble              3.2.1      2023-03-20 [1] CRAN (R 4.1.2)
#>  tidyselect          1.2.0      2022-10-10 [1] CRAN (R 4.1.2)
#>  utf8                1.2.4      2023-10-22 [1] CRAN (R 4.1.1)
#>  uuid                1.2-0      2024-01-14 [1] CRAN (R 4.1.1)
#>  vctrs               0.6.5      2023-12-01 [1] CRAN (R 4.1.1)
#>  withr               3.0.0      2024-01-16 [1] CRAN (R 4.1.1)
#>  xfun                0.25       2021-08-06 [2] CRAN (R 4.1.0)
#>  xgboost             1.4.1.1    2021-04-22 [2] CRAN (R 4.1.0)
#>  yaml                2.2.1      2020-02-01 [2] CRAN (R 4.1.0)
#>
#> [1] /Users/Lee/Library/R/x86_64/4.1/library
#> [2] /Library/Frameworks/R.framework/Versions/4.1/Resources/library
```

Lee-xli commented 1 month ago

Please ignore and close this: I realised that the learner class needs to be specified, as described in the mlr3 example: https://modeloriented.github.io/survex/articles/mlr3proba-usage.html. My apologies if this has taken up some of your time unnecessarily!
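
For readers who land here with the same error, here is a minimal sketch of what "specifying the learner class" can look like. In the reprex above, `explain()` reported `yhat.default` and produced a plain `explainer` object, for which `model_survshap()` has no method. Appending `"LearnerSurv"` to the class vector of the trained graph learner appears to be what the linked article does so that survex picks its mlr3proba method instead. Assumptions in this sketch: the graph learner is a stand-in (`po("encode", method = "treatment")` followed by `lrn("surv.coxph")`) rather than the tuned xgboost pipeline from the post, chosen because it returns `distr` predictions, and the names `graph_learner` and `cox_explainer` are illustrative only.

```r
library(survival)
library(mlr3proba)      # also attaches mlr3
library(mlr3pipelines)
library(survex)

veteran <- survival::veteran
veteran_task <- as_task_surv(veteran,
                             time = "time",
                             event = "status",
                             type = "right")

# Stand-in graph learner (assumption, not the pipeline from the post):
# factor encoding followed by a Cox model, which predicts `distr`.
graph_learner <- as_learner(
  po("encode", method = "treatment") %>>% lrn("surv.coxph")
)
graph_learner$train(veteran_task)

# Append "LearnerSurv" to the class vector after training, so that S3
# dispatch in explain() can find survex's method for mlr3proba learners
# instead of falling back to the default method.
class(graph_learner) <- c(class(graph_learner), "LearnerSurv")

cox_explainer <- explain(graph_learner,
                         data = veteran[, -c(3, 4)],
                         y = Surv(veteran$time, veteran$status),
                         label = "encode + coxph")

# With the class appended, this should include "surv_explainer".
class(cox_explainer)
```

For the pipeline in the original post, the equivalent line would go right after `xgb_basic$train(veteran_task)`. I have not verified that the tuned `surv.xgboost` learner returns the `distr` predictions survex needs; if it does not, a learner that does (the deprecation warning points to `surv.xgboost.cox`) or a distribution compositor as in the linked article may be required.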