giuseppec / iml

iml: interpretable machine learning R package
https://giuseppec.github.io/iml/

Shapley does not allow data.table inputs #217

Status: Closed (dandls closed this issue 1 month ago)

dandls commented 1 month ago

When I pass data.table(x.interest) instead of x.interest in the example code of iml::Shapley, I get an error. Reproducible example:

library("iml")
library("data.table")
library("rpart")
# First we fit a machine learning model on the Boston housing data
data("Boston", package = "MASS")
rf <- rpart(medv ~ ., data = Boston)
X <- Boston[-which(names(Boston) == "medv")]
mod <- Predictor$new(rf, data = X)

# Then we explain the first instance of the dataset with the Shapley method:
x.interest <- X[1, ]
shapley <- Shapley$new(mod, x.interest = data.table(x.interest))

The error message is:

Error in `[.data.table`(x.interest, setdiff(colnames(x.interest), predictor$data$y.names)) : 
  When i is a data.table (or character vector), the columns to join by must be specified using 'on=' argument (see ?data.table), by keying x (i.e. sorted, and, marked as sorted, see ?setkey), or by sharing column names between x and i (i.e., a natural join). Keyed joins might have further speed benefits on very large data due to x being sorted in RAM.
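
The call shown in the error, `[.data.table`(x.interest, setdiff(...)), indexes x.interest with a single character vector. On a data.frame that selects columns, but on a data.table a character i is treated as join values, which fails on an unkeyed table. The sketch below reproduces this difference outside of iml, using a small made-up table (the names df, dt and keep are hypothetical, not taken from the package), and shows two ways to select columns that work for a data.table. It is meant as an illustration of the likely cause, not as the fix applied in iml.

```r
library(data.table)

# Toy data standing in for x.interest; "y" plays the role of a target column
df <- data.frame(a = 1, b = 2, y = 3)
dt <- as.data.table(df)
keep <- setdiff(colnames(df), "y")

# data.frame: a single character index selects columns
df_sel <- df[keep]

# data.table: the same call interprets `keep` as join values and errors
# on an unkeyed table; the message is captured instead of stopping
dt_err <- tryCatch(dt[keep], error = function(e) conditionMessage(e))

# Column selection that works for a data.table
dt_sel <- dt[, keep, with = FALSE]

# Alternatively, convert back to a data.frame before indexing
df_from_dt <- as.data.frame(dt)[keep]
```

For the example in this issue, passing a plain data.frame (for instance via as.data.frame()) as x.interest avoids the problem, since the original example with a data.frame runs without this error.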

I am using the latest CRAN versions of all packages. This is my sessionInfo() output:

> sessionInfo()
R version 4.4.0 (2024-04-24)
Platform: x86_64-pc-linux-gnu
Running under: Debian GNU/Linux 12 (bookworm)

Matrix products: default
BLAS:   /usr/local/lib/R/lib/libRblas.so 
LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.11.0

locale:
 [1] LC_CTYPE=en_GB.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=en_GB.UTF-8        LC_COLLATE=en_GB.UTF-8    
 [5] LC_MONETARY=en_GB.UTF-8    LC_MESSAGES=en_GB.UTF-8   
 [7] LC_PAPER=en_GB.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=en_GB.UTF-8 LC_IDENTIFICATION=C       

time zone: Europe/Berlin
tzcode source: system (glibc)

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] data.table_1.16.2 rpart_4.1.23      iml_0.11.3       

loaded via a namespace (and not attached):
 [1] vctrs_0.6.5       cli_3.6.3         rlang_1.1.4       Formula_1.2-5    
 [5] generics_0.1.3    glue_1.7.0        colorspace_2.1-0  listenv_0.9.1    
 [9] backports_1.5.0   Metrics_0.1.4     scales_1.3.0      fansi_1.0.6      
[13] grid_4.4.0        munsell_0.5.1     tibble_3.2.1      lifecycle_1.0.4  
[17] compiler_4.4.0    dplyr_1.1.4       codetools_0.2-20  pkgconfig_2.0.3  
[21] future_1.34.0     digest_0.6.37     R6_2.5.1          tidyselect_1.2.1 
[25] utf8_1.2.4        pillar_1.9.0      parallelly_1.38.0 parallel_4.4.0   
[29] magrittr_2.0.3    checkmate_2.3.2   tools_4.4.0       gtable_0.3.5     
[33] globals_0.16.3    ggplot2_3.5.1