giuseppec / iml

iml: interpretable machine learning R package
https://giuseppec.github.io/iml/

Problems with caret sbf objects #165

Open nebfield opened 3 years ago

nebfield commented 3 years ago

Minimal dataset

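library(caret)
library(iml)
# Simulate a two-class dataset with 100 rows
dat <- twoClassSim(100)
# Keep the first five predictors and attach the class label as y
X <- dat[, 1:5]
X$y <- dat[["Class"]]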

Minimal, runnable code

tr <- sbf(
  y ~ .,
  data = X,
  sbfControl = sbfControl(
    functions = caretSBF,
    verbose = FALSE,
    method = "cv",
    number = 10
  ),
  trControl = trainControl(classProbs = TRUE),
  method = "svmLinear"
)

Problems start

Predictor$new(tr, X, y = "y")

Prediction task:unknown

pred <- Predictor$new(tr, X, y = "y")
Shapley$new(pred, X, X[1, ])

Error in colMeans(self$predictor$predict(private$sampler$get.x()): x must be numeric

These problems don't happen if I create the Predictor from the fit object extracted from the sbf result (Predictor$new(tr$fit, ...)), but this is a bad idea.
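One possible workaround, sketched under assumptions rather than confirmed as the intended fix: the colMeans error suggests that predict() on the sbf object returns a data frame that mixes a non-numeric column (the predicted class) with numeric class probabilities. If so, a custom prediction function passed through the predict.function argument of Predictor$new() could keep only the numeric columns. The helper names numeric_columns_only and predict_sbf_numeric are hypothetical and are not part of iml or caret.

```r
# Hypothetical helper (not part of iml or caret): keep only the numeric
# columns of a prediction data frame, e.g. the class probabilities.
numeric_columns_only <- function(pred) {
  pred <- as.data.frame(pred)
  pred[, vapply(pred, is.numeric, logical(1)), drop = FALSE]
}

# Custom prediction function in the (model, newdata) form that
# Predictor$new() accepts through its predict.function argument.
# Assumption: predict() on the sbf object returns class probabilities
# alongside a factor column holding the predicted class.
predict_sbf_numeric <- function(model, newdata) {
  numeric_columns_only(predict(model, newdata = newdata))
}

# Toy illustration with a mock prediction data frame (made-up values)
mock <- data.frame(
  pred   = factor(c("Class1", "Class2")),
  Class1 = c(0.8, 0.3),
  Class2 = c(0.2, 0.7)
)
colMeans(numeric_columns_only(mock))  # works: only numeric columns remain

# Untested sketch of how this might be used with tr and X from above:
# pred <- Predictor$new(tr, data = X, y = "y",
#                       predict.function = predict_sbf_numeric)
# Shapley$new(pred, x.interest = X[1, ])
```

Whether this also resolves the "unknown" prediction task is not verified here; the sketch only addresses the non-numeric column that colMeans rejects.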

Thanks for making iml. Your book is amazing :grin:

Session info


R version 4.0.2 (2020-06-22)
Platform: x86_64-apple-darwin17.0 (64-bit)
Running under: macOS  10.16

Matrix products: default
LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib

locale:
[1] en_GB.UTF-8/en_GB.UTF-8/en_GB.UTF-8/C/en_GB.UTF-8/en_GB.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] iml_0.10.1      caret_6.0-86    ggplot2_3.3.2   lattice_0.20-41

loaded via a namespace (and not attached):
 [1] Metrics_0.1.4        statmod_1.4.35       tidyselect_1.1.0    
 [4] kernlab_0.9-29       listenv_0.8.0        purrr_0.3.4         
 [7] reshape2_1.4.4       splines_4.0.2        colorspace_1.4-1    
[10] vctrs_0.3.4          generics_0.1.0       stats4_4.0.2        
[13] survival_3.2-7       prodlim_2019.11.13   rlang_0.4.8         
[16] e1071_1.7-4          ModelMetrics_1.2.2.2 nloptr_1.2.2.2      
[19] pillar_1.4.6         glue_1.4.2           withr_2.3.0         
[22] foreach_1.5.1        lifecycle_0.2.0      plyr_1.8.6          
[25] lava_1.6.8           stringr_1.4.0        timeDate_3043.102   
[28] munsell_0.5.0        gtable_0.3.0         prediction_0.3.14   
[31] future_1.20.1        recipes_0.1.14       codetools_0.2-16    
[34] parallel_4.0.2       class_7.3-17         Rcpp_1.0.5          
[37] backports_1.1.10     checkmate_2.0.0      scales_1.1.1        
[40] ipred_0.9-9          parallelly_1.21.0    lme4_1.1-25         
[43] digest_0.6.27        stringi_1.5.3        dplyr_1.0.2         
[46] grid_4.0.2           tools_4.0.2          magrittr_1.5        
[49] tibble_3.0.4         crayon_1.3.4         tidyr_1.1.2         
[52] pkgconfig_2.0.3      MASS_7.3-53          ellipsis_0.3.1      
[55] Matrix_1.2-18        data.table_1.13.2    pROC_1.16.2         
[58] lubridate_1.7.9      gower_0.2.2          minqa_1.2.4         
[61] rstudioapi_0.11      iterators_1.0.13     globals_0.13.1      
[64] R6_2.5.0             boot_1.3-25          rpart_4.1-15        
[67] nnet_7.3-14          nlme_3.1-150         compiler_4.0.2