mlr-org / mlr3

mlr3: Machine Learning in R - next generation
https://mlr3.mlr-org.com
GNU Lesser General Public License v3.0
927 stars 86 forks source link

Error for bootstrap resampling on stacked/ensemble learner #903

Open h-a-graham opened 1 year ago

h-a-graham commented 1 year ago

Description

I'm trying to generate bootstrapped resamples for an ensemble model, but it throws an error. This seems to result from the duplication of row_ids. I suppose these duplicate rows are to be expected because bootstrap resampling samples with replacement, but I'm not sure why this fails in the ensemble context and not with a single learner. I'm posting here after previously asking this question on Stack Overflow.

Reproducible example

library(mlr3)
library(mlr3learners)
library(mlr3pipelines)
library(progressr)
lgr::get_logger("mlr3")$set_threshold("warn")

# Stacked ensemble: the cross-validated predictions of an SVM and a
# regression tree are joined into one feature set, on which a linear
# model ("master") is fitted. The finished graph is wrapped as a learner.
ens.lrnr <- as_learner(
  gunion(list(
    po("learner_cv", lrn("regr.svm")),
    po("learner_cv", lrn("regr.rpart"))
  )) %>>%
    po("featureunion") %>>%
    lrn("regr.lm", id = "master")
)

# Boston housing task with the "town" and "chas" features dropped.
task <- tsk("boston_housing")
task$select(setdiff(task$feature_names, c("town", "chas")))

# Bootstrap-resample `.lrnr` on the script-level `task`.
#
# .lrnr: the learner to evaluate.
# Returns whatever mlr3::resample() returns for 100 bootstrap repeats with
# ratio = 1; fitted models are not stored.
boot_res <- function(.lrnr) {
  bootstrap <- rsmp("bootstrap", repeats = 100, ratio = 1)

  # Run inside with_progress() so progress updates are reported.
  progressr::with_progress({
    mlr3::resample(task, .lrnr, bootstrap, store_models = FALSE)
  })
}

# Baseline: bootstrap resampling of a single learner works.
rpart_boot <- boot_res(lrn("regr.rpart"))

# The same call with the stacked ensemble learner fails; the error about
# duplicated primary-key values is shown in the output below.
ens_boot <- boot_res(ens.lrnr)
#> Error in as_data_backend.data.frame(data, primary_key = row_ids): Assertion on 'data[[primary_key]]' failed: Contains duplicated values, position 7.
#> This happened PipeOp regr.svm's $train()

Created on 2023-02-17 with reprex v2.0.2

h-a-graham commented 1 year ago

Update: the following appears to give me what I need, but I presume that the original reprex should, in principle, still work?

library(mlr3)
library(mlr3learners)
library(mlr3pipelines)
library(progressr)
lgr::get_logger("mlr3")$set_threshold("warn")

# Stacked ensemble: the cross-validated predictions of an SVM and a
# regression tree are joined into one feature set, on which a linear
# model ("master") is fitted. The finished graph is wrapped as a learner.
ens.lrnr <- as_learner(
  gunion(list(
    po("learner_cv", lrn("regr.svm")),
    po("learner_cv", lrn("regr.rpart"))
  )) %>>%
    po("featureunion") %>>%
    lrn("regr.lm", id = "master")
)

# Boston housing task with the "town" and "chas" features dropped.
task <- tsk("boston_housing")
task$select(setdiff(task$feature_names, c("town", "chas")))

# Manual bootstrap of `.lrnr` on the script-level `task`.
#
# .lrnr: the learner (or graph) to fit on each resampled copy of the data.
# n:     number of replicates to build (default 10).
# Returns the result of predicting on `task` with the replicated graph; in
# the printed output this is a named list with one prediction per replicate.
boot_res <- function(.lrnr, n = 10) {
  # One replicate: sample the full task size (frac = 1) with replacement,
  # then pass the sample on to the learner.
  resample_then_fit <- po(
    "subsample",
    param_vals = list(frac = 1, replace = TRUE)
  ) %>>% .lrnr

  # Replicate that graph n times.
  replicas <- pipeline_greplicate(resample_then_fit, n = n)

  # Train first, then predict on the same task.
  replicas$train(task)
  replicas$predict(task)
}

# With the subsample + greplicate workaround the ensemble learner now trains
# and predicts without error; print the predictions of the first two
# replicates.
ens_boot <- boot_res(ens.lrnr)
print(ens_boot[1:2])
#> $regr.svm.regr.rpart.featureunion.master_1.output
#> <PredictionRegr> for 506 observations:
#>     row_ids truth response
#>           1  24.0 25.14968
#>           2  21.6 21.54997
#>           3  34.7 34.39916
#> ---                       
#>         504  23.9 24.12859
#>         505  22.0 21.06487
#>         506  11.9 16.49974
#> 
#> $regr.svm.regr.rpart.featureunion.master_2.output
#> <PredictionRegr> for 506 observations:
#>     row_ids truth response
#>           1  24.0 24.30699
#>           2  21.6 23.05628
#>           3  34.7 35.69923
#> ---                       
#>         504  23.9 24.25398
#>         505  22.0 23.30728
#>         506  11.9 19.32718

Created on 2023-02-20 with reprex v2.0.2