mlr-org / mlr3pipelines

Dataflow Programming for Machine Learning in R
https://mlr3pipelines.mlr-org.com/
GNU Lesser General Public License v3.0

explain_mlr3 fails on ensemble model #642

Closed. Tato14 closed this issue 2 years ago.

Tato14 commented 2 years ago

I created an ensemble model as follows:

library(mlr3)
library(mlr3learners)
library(mlr3tuning)
library(mlr3pipelines)
library(mlr3verse)
library(mlr3extralearners)
library("DALEX")
library("DALEXtra")
library(tidyverse)
set.seed(1)
## Recreate the dataset
data("penguins", package = "palmerpenguins")
str(penguins)
penguins = na.omit(penguins)
# Named lrn_xgb so the object does not shadow the lrn() function
lrn_xgb = lrn("classif.xgboost", nrounds = 100, predict_type = "prob")
task_peng = as_task_classif(penguins, target = "species")
# Convert ordered factors to integers
ord_to_int = po("colapply", applicator = as.integer,
                affect_columns = selector_type("ordered"))
p0 = ppl("robustify", task_peng, lrn_xgb) %>>% ord_to_int %>>% po("imputeoor")
gr_fil = p0
# Train the preprocessing graph and extract the preprocessed data
graph_data <- gr_fil$train(task_peng)
peng_task <- graph_data[[1]]$data()
task_peng_ready = as_task_classif(peng_task, target = "species")

## Recreate the ensemble
myvar <- list()
learners_all = as.data.table(list_mlr3learners())
# Keep classification learners that support weights and integer features
learners_to_try <- learners_all %>%
  filter(class == "classif") %>%
  filter(grepl("weights", properties)) %>%
  filter(grepl("integer", feature_types))
learners_to_try <- learners_to_try %>% dplyr::select(c(name, id))
learners_to_try <- as.data.frame(learners_to_try)

# Wrap two of the selected learners (rows 5 and 20) in cross-validated PipeOps
for (i in learners_to_try$name[c(5, 20)]) {
  learner_id <- learners_to_try[learners_to_try$name == i, ][[2]]
  myvar[[i]] <- po("learner_cv", learner = lrn(learner_id, id = i, predict_type = "prob"))
}
myvar$ranger$param_set$values <- list(mtry = 1, num.trees = 1571, sample.fraction = 0.3242656,
                                      min.node.size = 70, resampling.method = "cv",
                                      resampling.folds = 3, resampling.keep_response = FALSE,
                                      num.threads = 1)

graph_stack = gunion(myvar) %>>%
  po("featureunion") 

lrn_avg = LearnerClassifAvg$new(id = "classif.avg")
lrn_avg$predict_type = "prob"

graph_stack_avg <- graph_stack %>>% lrn_avg

graph_stack_avg$train(task_peng_ready)
graph_stack_avg$pipeops 

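## Recreate explain_mlr3
# Recode the species labels as 0, 1, 2
y = peng_task$species
y = factor(y, levels = c("Adelie", "Chinstrap", "Gentoo"), labels = c(0, 1, 2))
y = as.numeric(as.character(y))
# Drop the target column (the first column of the task data)
penguins_2 <- peng_task[, -1]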

ensemble_exp <- explain_mlr3(graph_stack_avg,
                             data     = penguins_2,
                             y        = y,
                             label    = "Ensemble",
                             colorize = FALSE)

explain_ens <- model_parts(ensemble_exp, B=10)

This gives the following error:

Error in check_item(data[[idx]], typetable[[operation]][[idx]], varname = sprintf("%s %s (\"%s\") of PipeOp %s's $%s()",  : 
  Assertion on 'input 1 ("input") of PipeOp cv_glmnet.cv_glmnet's $predict()' failed: Must inherit from class 'TaskClassif', but has classes 'TaskRegr','TaskSupervised','Task','R6'.

I am unsure how to solve this. Any ideas? Thanks!

hbaniecki commented 2 years ago

Hi, when I try to run your example, I get:

[screenshot not preserved]

My session info:

R version 4.1.1 (2021-08-10)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19044)

Matrix products: default

locale:
[1] LC_COLLATE=English_United Kingdom.1252  LC_CTYPE=English_United Kingdom.1252   
[3] LC_MONETARY=English_United Kingdom.1252 LC_NUMERIC=C                           
[5] LC_TIME=English_United Kingdom.1252    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] forcats_0.5.1            stringr_1.4.0            dplyr_1.0.7             
 [4] purrr_0.3.4              readr_2.0.2              tidyr_1.1.4             
 [7] tibble_3.1.6             ggplot2_3.3.5            tidyverse_1.3.1         
[10] DALEXtra_2.1.1           DALEX_2.3.0              mlr3extralearners_0.5.18
[13] mlr3verse_0.2.2          mlr3pipelines_0.4.0      mlr3tuning_0.9.0        
[16] paradox_0.7.1            mlr3learners_0.5.1       mlr3_0.13.0

Tato14 commented 2 years ago

@hbaniecki I am not sure why you cannot reproduce the example. In any case, the error seems to be that I forgot to wrap the graph in a learner with GraphLearner$new(graph). Doing so appears to solve the issue.
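A minimal, self-contained sketch of this fix follows. It assumes the iris task and two simple base learners as stand-ins for the penguin ensemble, and it uses prediction averaging with po("classifavg") in place of the stacked LearnerClassifAvg. The Graph is wrapped with GraphLearner$new(), the wrapper is trained as a Learner, and only then passed to explain_mlr3().

```r
library(mlr3)
library(mlr3pipelines)
library(DALEX)
library(DALEXtra)

task <- tsk("iris")

# Stand-in ensemble (assumption): two base learners whose class probabilities
# are averaged by po("classifavg"), replacing the stacked LearnerClassifAvg
graph <- gunion(list(
  po("learner", lrn("classif.rpart", predict_type = "prob")),
  po("learner", lrn("classif.featureless", predict_type = "prob"))
)) %>>%
  po("classifavg", innum = 2)

# Wrap the Graph so that explain_mlr3() receives an mlr3 Learner,
# then train the wrapper itself as a Learner
glrn <- GraphLearner$new(graph)
glrn$predict_type <- "prob"
glrn$train(task)

# Features only; the target is passed separately as a factor
features <- as.data.frame(task$data(cols = task$feature_names))
explainer <- explain_mlr3(glrn,
                          data     = features,
                          y        = task$truth(),
                          label    = "Ensemble",
                          colorize = FALSE,
                          verbose  = FALSE)
```

Note that the wrapper has to be trained on its own: a freshly created GraphLearner has no learner state, so an earlier graph_stack_avg$train() call does not make it ready for prediction.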

mllg commented 2 years ago

Transferred the issue. Maybe we can improve the error messages here, or check whether we can do an auto-conversion.
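One possible shape for such an auto-conversion, written as a hypothetical user-side helper (as_trained_learner() is an illustrative name, not an existing mlr3, mlr3pipelines or DALEXtra function): a Graph is wrapped in a GraphLearner, anything else that is not a Learner is rejected with a clear message, and an untrained Learner is trained on the supplied task.

```r
library(mlr3)
library(mlr3pipelines)

# Hypothetical helper: accept a Learner or a Graph and return a trained Learner
as_trained_learner <- function(model, task) {
  if (inherits(model, "Graph")) {
    # The new GraphLearner has no learner state yet, so it is trained below
    model <- GraphLearner$new(model)
  }
  if (!inherits(model, "Learner")) {
    stop("Expected an mlr3 Learner or Graph, got: ",
         paste(class(model), collapse = ", "))
  }
  if (is.null(model$model)) {
    model$train(task)
  }
  model
}
```

Training inside the helper is a design choice for this sketch; a library-side version might instead refuse untrained input and ask the user to train it.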

mb706 commented 2 years ago

@hbaniecki the problem here is probably that the value of myvar depends on the result of list_mlr3learners(), which in turn depends on your installed version of mlr3extralearners (and to some degree on the versions of mlr3 and mlr3learners that were installed on the machine where your mlr3extralearners package was built).

You probably have a different myvar than Tato14 has.
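One way to make the example reproducible across installations is to pick the base learners by explicit id instead of by row position in list_mlr3learners(). The sketch below assumes the two learners were ranger and cv_glmnet, which is inferred from myvar$ranger in the question and the cv_glmnet PipeOp named in the error message; it needs mlr3learners (and the ranger and glmnet packages once the learners are trained).

```r
library(mlr3)
library(mlr3learners)
library(mlr3pipelines)

# Assumption: these are the two learners that rows 5 and 20 selected
learner_ids <- c(ranger = "classif.ranger", cv_glmnet = "classif.cv_glmnet")

# Build the cross-validated PipeOps from explicit ids, so the ensemble no
# longer depends on the row order of list_mlr3learners()
myvar <- lapply(names(learner_ids), function(name) {
  po("learner_cv", learner = lrn(learner_ids[[name]], id = name, predict_type = "prob"))
})
names(myvar) <- names(learner_ids)
```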

mb706 commented 2 years ago

One problem is in predict.Graph: we chose to emulate a lightweight Learner there, but it always assumes regression. Instead, it should assume whatever task type the Graph needs. Ideally, however, DALEXtra should auto-convert the argument of explain_mlr3, or at least check that it is a Learner (https://github.com/ModelOriented/DALEXtra/issues/71).
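For comparison, a GraphLearner infers its task type from the Graph it wraps, so predicting on new data with a classification Graph yields a classification prediction. A small sketch, assuming the iris task and a scale-then-rpart Graph as an arbitrary example:

```r
library(mlr3)
library(mlr3pipelines)

# Arbitrary example Graph: scale the features, then fit a classification tree
graph <- po("scale") %>>% lrn("classif.rpart")

# The GraphLearner takes its task type from the Graph's output
glrn <- GraphLearner$new(graph)
print(glrn$task_type)

glrn$train(tsk("iris"))

# predict_newdata() turns the data frame into a task of the learner's own
# type, so the result is a classification prediction
pred <- glrn$predict_newdata(iris[1:5, 1:4])
print(class(pred))
```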