mlr-org / mlr3learners

Recommended learners for mlr3
https://mlr3learners.mlr-org.com
GNU Lesser General Public License v3.0

`predcontrib = TRUE` in xgboost classification learner causes errors #236

Open TylerGrantSmith opened 2 years ago

TylerGrantSmith commented 2 years ago

Using `predcontrib = TRUE` to get Shapley values with the xgboost classification learner causes an error, because mlr3 expects a different prediction format.

With `predcontrib = TRUE`, xgboost returns a matrix of per-feature contributions rather than a vector of probabilities. The default behavior of `mlr3learners:::LearnerClassifXgboost$private_methods$.predict` flattens that matrix, which causes a mismatch in the number of rows.
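To illustrate the shape difference outside of mlr3, here is a minimal sketch that calls xgboost directly on synthetic data. The data, the feature names `a`, `b`, `c`, and the variable names are made up for illustration. It assumes an installed xgboost version whose `predict()` accepts a plain numeric matrix as `newdata`.

```r
library(xgboost)

# Synthetic binary classification data (hypothetical, for illustration only)
set.seed(1)
n = 50
p = 3
x = matrix(rnorm(n * p), nrow = n, dimnames = list(NULL, c("a", "b", "c")))
y = as.numeric(x[, "a"] + rnorm(n) > 0)

model = xgb.train(
  params = list(objective = "binary:logistic"),
  data = xgb.DMatrix(x, label = y),
  nrounds = 5
)

# Default prediction: one probability per row
prob = predict(model, newdata = x)
length(prob)

# With predcontrib = TRUE: one row per observation, one column per feature
# plus one extra bias column
contrib = predict(model, newdata = x, predcontrib = TRUE)
dim(contrib)

# Flattening the matrix yields n * (p + 1) values instead of n
length(as.vector(contrib)) - n
```

The same arithmetic appears to explain the error in this report: 333 complete rows times 10 contribution columns, minus the 333 expected predictions, gives 2997. This assumes 9 features after treatment encoding plus the bias column.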

How can I get these Shapley contributions using xgboost's built-in method?

```r
library(mlr3)
#> Warning: package 'mlr3' was built under R version 4.0.5
#> Registered S3 methods overwritten by 'parallelly':
#>   method                     from  
#>   c.cluster                  future
#>   print.RichSOCKcluster      future
#>   stopCluster.RichMPIcluster future
#>   summary.RichSOCKcluster    future
#>   summary.RichSOCKnode       future
library(mlr3learners)
#> Warning: package 'mlr3learners' was built under R version 4.0.5
library(mlr3pipelines)
#> Warning: package 'mlr3pipelines' was built under R version 4.0.5

penguins = palmerpenguins::penguins
penguins = penguins[!is.na(penguins$sex), ]
task = as_task_classif(penguins, target = "sex", positive = "male")

learner = lrn("classif.xgboost",
              predict_type = "prob",
              predcontrib = TRUE,
              nrounds = 10)

fencoder = po("encode",
              method = "treatment",
              affect_columns = selector_type("factor"))

graph = fencoder %>>% learner
graph_learner = as_learner(graph)

graph_learner$train(task)
pred = graph_learner$predict(task)
#> Error: Predicted prob contains 2997 additional predictions without matching rows
#> This happened PipeOp classif.xgboost's $predict()
```
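One possible workaround, sketched here under stated assumptions and not confirmed as the maintainers' recommended approach, is to leave `predcontrib` unset on the learner and call xgboost's own `predict()` on the fitted booster. The sketch assumes that `learner$model` holds the underlying xgboost booster and that the encoded task's feature order matches the order used at training time. It runs the encoder and the learner as two manual steps instead of a `GraphLearner`, to avoid depending on graph internals.

```r
library(mlr3)
library(mlr3learners)
library(mlr3pipelines)

penguins = palmerpenguins::penguins
penguins = penguins[!is.na(penguins$sex), ]
task = as_task_classif(penguins, target = "sex", positive = "male")

# Encode factors first, so the numeric feature matrix is available later
fencoder = po("encode",
              method = "treatment",
              affect_columns = selector_type("factor"))
encoded_task = fencoder$train(list(task))[[1]]

# Train WITHOUT predcontrib, so mlr3 predictions keep their usual format
learner = lrn("classif.xgboost", predict_type = "prob", nrounds = 10)
learner$train(encoded_task)
pred = learner$predict(encoded_task)

# Assumption: learner$model is the fitted xgboost booster
booster = learner$model

# Rebuild the numeric matrix in the feature order of the encoded task
x = data.matrix(encoded_task$data(cols = encoded_task$feature_names))

# Ask xgboost directly for per-feature contributions
contrib = predict(booster, newdata = x, predcontrib = TRUE)
dim(contrib)
```

With this layout, `pred` remains an ordinary mlr3 prediction object and `contrib` is a separate matrix with one row per observation, so the two never have to share a format.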