mlr-org / mlr3extralearners

Extra learners for use in mlr3.
https://mlr3extralearners.mlr-org.com/

`surv.aorsf`: Workaround for `leaf_min_events = X should be <= Y (number of events divided by 2)` in tuning setting? #383

Open jemus42 opened 2 weeks ago

jemus42 commented 2 weeks ago

Description

In surv.aorsf, when the leaf_min_events parameter is tuned, the allowed values depend on the number of events in the respective task (or the subset of the task used for training). This leads to errors in our benchmark, where we tune leaf_min_events in the range 5 to 50: for tasks with few observations, resampling leaves too few events in the training data and we hit the error above, e.g. leaf_min_events = 25 should be <= 20 (number of events divided by 2).

In practice we encapsulate the learner and use a fallback (KM) to impute results, but of course it would be better to not even attempt to evaluate "invalid" hyperparameter configurations.
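For reference, a minimal sketch of that encapsulation + fallback setup, using the learner fields as they exist in mlr3 <= 0.20 (newer mlr3 versions expose an $encapsulate() method instead):

```r
library(mlr3)
library(mlr3proba)          # survival task type and surv.kaplan
library(mlr3extralearners)  # surv.aorsf

lrn_aorsf = lrn("surv.aorsf")

# Catch errors during train/predict instead of aborting the whole benchmark
lrn_aorsf$encapsulate = c(train = "evaluate", predict = "evaluate")

# Impute predictions from a Kaplan-Meier baseline whenever aorsf errors,
# e.g. for configurations where leaf_min_events exceeds n_events / 2
lrn_aorsf$fallback = lrn("surv.kaplan")
```

This keeps the benchmark running, but the failed configurations still waste budget and are scored by the fallback rather than the learner itself.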

A common example of a data-dependent parameter is mtry, for which we introduced the mtry.ratio proxy parameter to tune mtry on a scale from 0 to 1 rather than from 1 to nfeatures. I am now wondering if it makes sense to introduce a similar proxy parameter here, or if we can get away with an .extra_trafo of some sort (but I don't think .extra_trafo has access to the necessary information?).
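To make the idea concrete, here is a sketch of what such a proxy could look like. The parameter name leaf_min_events_ratio and the helper below are purely hypothetical, chosen to mirror the existing mtry_ratio mechanics:

```r
# Hypothetical proxy: tune leaf_min_events_ratio on [0, 1] and convert it
# to a valid leaf_min_events at train time, once the number of events in
# the training data is known. aorsf requires leaf_min_events <= n_events / 2,
# so ratios above 0.5 saturate at that bound in this sketch.
leaf_min_events_from_ratio = function(ratio, n_events) {
  upper = max(1L, floor(n_events / 2))
  max(1L, min(upper, as.integer(ceiling(ratio * n_events))))
}

leaf_min_events_from_ratio(0.25, n_events = 40)  # 10
leaf_min_events_from_ratio(0.90, n_events = 40)  # capped at 20
```

The key point is that the conversion must happen inside the learner's $train(), where the training data (and hence the event count) is available, not in a search-space trafo.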

@bcjaeger, if you have any insights here let us know!

Below is a reprex with an AutoTuner setup reproducing the error message, with the tuning spaces copied verbatim from our benchmark.

Reproducible example

library(mlr3)
library(mlr3proba)
library(mlr3extralearners)
library(mlr3pipelines)
library(mlr3tuning)
#> Loading required package: paradox

set.seed(123)
task = tsk("lung")
task$set_col_roles("status", add_to = "stratum")
# Creating a deliberately small training set
splits = partition(task, ratio = 0.5)

lrn_base = lrn("surv.aorsf", n_tree = 1000, control_type = "fast")

# Making it a graph is not strictly necessary but closer to my real code
lrn_graph = po("fixfactors") %>>%
  po("imputesample", affect_columns = selector_type("factor")) %>>%
  po("removeconstants") %>>%
  lrn_base |>
  as_learner()

lrn_auto = auto_tuner(
  learner = lrn_graph,
  search_space = ps(
    surv.aorsf.mtry_ratio = p_dbl(0, 1),
    surv.aorsf.leaf_min_events = p_int(5, 50),
    .extra_trafo = function(x, param_set) {
      x$surv.aorsf.split_min_obs = x$surv.aorsf.leaf_min_events + 5L
      x
    }
  ),
  resampling = rsmp("repeated_cv", folds = 3, repeats = 2),
  measure = msr("surv.cindex"),
  terminator = trm("evals", n_evals = 20, k = 0),
  tuner = tnr("grid_search", resolution = 10),
  store_models = TRUE,
  store_benchmark_result = TRUE,
  store_tuning_instance = TRUE
)

lrn_auto$train(task, row_ids = splits$train)
#> INFO  [16:57:34.387] [bbotk] Starting to optimize 2 parameter(s) with '<OptimizerBatchGridSearch>' and '<TerminatorEvals> [n_evals=20, k=0]'
#> INFO  [16:57:34.406] [bbotk] Evaluating 1 configuration(s)
#> INFO  [16:57:34.430] [mlr3] Running benchmark with 6 resampling iterations
#> INFO  [16:57:34.455] [mlr3] Applying learner 'fixfactors.imputesample.removeconstants.surv.aorsf' on task 'lung' (iter 1/6)
#> INFO  [16:57:34.492] [mlr3] Applying learner 'fixfactors.imputesample.removeconstants.surv.aorsf' on task 'lung' (iter 2/6)
#> INFO  [16:57:34.558] [mlr3] Applying learner 'fixfactors.imputesample.removeconstants.surv.aorsf' on task 'lung' (iter 3/6)
#> INFO  [16:57:34.585] [mlr3] Applying learner 'fixfactors.imputesample.removeconstants.surv.aorsf' on task 'lung' (iter 4/6)
#> INFO  [16:57:34.714] [mlr3] Applying learner 'fixfactors.imputesample.removeconstants.surv.aorsf' on task 'lung' (iter 5/6)
#> INFO  [16:57:34.739] [mlr3] Applying learner 'fixfactors.imputesample.removeconstants.surv.aorsf' on task 'lung' (iter 6/6)
#> Error: leaf_min_events = 45 should be <= 20 (number of events divided by 2)
#> This happened PipeOp surv.aorsf's $train()

Created on 2024-09-09 with reprex v2.1.1

Session info

``` r
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.4.1 (2024-06-14)
#>  os       macOS Sonoma 14.6.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Europe/Berlin
#>  date     2024-09-09
#>  pandoc   3.2 @ /System/Volumes/Data/Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/aarch64/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package            * version     date (UTC) lib source
#>  aorsf                0.1.5       2024-05-30 [1] CRAN (R 4.4.0)
#>  backports            1.5.0       2024-05-23 [1] CRAN (R 4.4.0)
#>  bbotk                1.0.1       2024-07-24 [1] CRAN (R 4.4.0)
#>  checkmate            2.3.2       2024-07-29 [1] CRAN (R 4.4.0)
#>  cli                  3.6.3       2024-06-21 [1] CRAN (R 4.4.0)
#>  codetools            0.2-20      2024-03-31 [2] CRAN (R 4.4.1)
#>  collapse             2.0.16      2024-08-21 [1] CRAN (R 4.4.1)
#>  colorspace           2.1-1       2024-07-26 [1] CRAN (R 4.4.0)
#>  crayon               1.5.3       2024-06-20 [1] CRAN (R 4.4.0)
#>  data.table           1.16.0      2024-08-27 [1] CRAN (R 4.4.1)
#>  dictionar6           0.1.3       2021-09-13 [1] CRAN (R 4.4.0)
#>  digest               0.6.37      2024-08-19 [1] CRAN (R 4.4.1)
#>  distr6               1.8.4       2024-07-11 [1] Github (xoopR/distr6@a642cd3)
#>  dplyr                1.1.4       2023-11-17 [1] CRAN (R 4.4.0)
#>  evaluate             0.24.0      2024-06-10 [1] CRAN (R 4.4.0)
#>  fansi                1.0.6       2023-12-08 [1] CRAN (R 4.4.0)
#>  fastmap              1.2.0       2024-05-15 [1] CRAN (R 4.4.0)
#>  fs                   1.6.4       2024-04-25 [1] CRAN (R 4.4.0)
#>  future               1.34.0      2024-07-29 [1] CRAN (R 4.4.0)
#>  future.apply         1.11.2      2024-03-28 [1] CRAN (R 4.4.0)
#>  generics             0.1.3       2022-07-05 [1] CRAN (R 4.4.0)
#>  ggplot2              3.5.1       2024-04-23 [1] CRAN (R 4.4.0)
#>  globals              0.16.3      2024-03-08 [1] CRAN (R 4.4.0)
#>  glue                 1.7.0       2024-01-09 [1] CRAN (R 4.4.0)
#>  gtable               0.3.5       2024-04-22 [1] CRAN (R 4.4.0)
#>  htmltools            0.5.8.1     2024-04-04 [1] CRAN (R 4.4.0)
#>  knitr                1.48        2024-07-07 [1] CRAN (R 4.4.0)
#>  lattice              0.22-6      2024-03-20 [2] CRAN (R 4.4.1)
#>  lgr                  0.4.4       2022-09-05 [1] CRAN (R 4.4.0)
#>  lifecycle            1.0.4       2023-11-07 [1] CRAN (R 4.4.0)
#>  listenv              0.9.1       2024-01-29 [1] CRAN (R 4.4.0)
#>  magrittr             2.0.3       2022-03-30 [1] CRAN (R 4.4.0)
#>  Matrix               1.7-0       2024-04-26 [2] CRAN (R 4.4.1)
#>  mlr3               * 0.20.2.9000 2024-09-08 [1] Github (mlr-org/mlr3@9539eff)
#>  mlr3extralearners  * 0.9.0-9000  2024-08-24 [1] Github (mlr-org/mlr3extralearners@98d790f)
#>  mlr3misc             0.15.1      2024-06-24 [1] CRAN (R 4.4.0)
#>  mlr3pipelines      * 0.6.0       2024-07-01 [1] CRAN (R 4.4.0)
#>  mlr3proba          * 0.6.8       2024-09-08 [1] Github (mlr-org/mlr3proba@0e5c80b)
#>  mlr3tuning         * 1.0.0       2024-06-29 [1] CRAN (R 4.4.0)
#>  mlr3viz              0.9.0.9000  2024-08-15 [1] Github (mlr-org/mlr3viz@db6e547)
#>  munsell              0.5.1       2024-04-01 [1] CRAN (R 4.4.0)
#>  ooplah               0.2.0       2022-01-21 [1] CRAN (R 4.4.0)
#>  palmerpenguins       0.1.1       2022-08-15 [1] CRAN (R 4.4.0)
#>  paradox            * 1.0.1       2024-07-09 [1] CRAN (R 4.4.0)
#>  parallelly           1.38.0      2024-07-27 [1] CRAN (R 4.4.0)
#>  param6               0.2.4       2024-04-26 [1] Github (xoopR/param6@0fa3577)
#>  pillar               1.9.0       2023-03-22 [1] CRAN (R 4.4.0)
#>  pkgconfig            2.0.3       2019-09-22 [1] CRAN (R 4.4.0)
#>  pracma               2.4.4       2023-11-10 [1] CRAN (R 4.4.0)
#>  R6                   2.5.1       2021-08-19 [1] CRAN (R 4.4.0)
#>  Rcpp                 1.0.13      2024-07-17 [1] CRAN (R 4.4.0)
#>  reprex               2.1.1       2024-07-06 [1] CRAN (R 4.4.0)
#>  rlang                1.1.4       2024-06-04 [1] CRAN (R 4.4.0)
#>  rmarkdown            2.28        2024-08-17 [1] CRAN (R 4.4.0)
#>  rstudioapi           0.16.0      2024-03-24 [1] CRAN (R 4.4.0)
#>  scales               1.3.0       2023-11-28 [1] CRAN (R 4.4.0)
#>  sessioninfo          1.2.2       2021-12-06 [1] CRAN (R 4.4.0)
#>  set6                 0.2.6       2024-04-26 [1] Github (xoopR/set6@a901255)
#>  survival             3.7-0       2024-06-05 [1] CRAN (R 4.4.0)
#>  tibble               3.2.1       2023-03-20 [1] CRAN (R 4.4.0)
#>  tidyselect           1.2.1       2024-03-11 [1] CRAN (R 4.4.0)
#>  utf8                 1.2.4       2023-10-22 [1] CRAN (R 4.4.0)
#>  uuid                 1.2-1       2024-07-29 [1] CRAN (R 4.4.0)
#>  vctrs                0.6.5       2023-12-01 [1] CRAN (R 4.4.0)
#>  withr                3.0.1       2024-07-31 [1] CRAN (R 4.4.0)
#>  xfun                 0.47        2024-08-17 [1] CRAN (R 4.4.0)
#>  yaml                 2.3.10      2024-07-26 [1] CRAN (R 4.4.0)
#>
#>  [1] /Users/Lukas/Library/R/arm64/4.4/library
#>  [2] /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library
#>
#> ──────────────────────────────────────────────────────────────────────────────
```
bcjaeger commented 2 weeks ago

I am now wondering if it makes sense to introduce a similar proxy parameter here

I like the idea of tuning leaf_min_events as a ratio (similar to mtry.ratio). Would aorsf be the only learner that would benefit from that?

jemus42 commented 2 weeks ago

Looking over rfsrc and ranger, I don't think they have parameters equivalent to leaf_min_events, so I guess this would then be aorsf-specific?

Searching for the convert_ratio helper, which is used to implement mtry_ratio, I can see that only one other parameter (sampsize.ratio in surv.rfsrc) uses it. So there is precedent for tuning parameters other than mtry as a ratio, but it appears to be uncommon. I'll give this a go in a PR and then we can see whether it seems like a good idea.
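For readers following along, the core of the convert_ratio idea, paraphrased (this is a sketch of the mechanism, not the actual mlr3extralearners source): at $train() time, the ratio parameter is dropped and the absolute parameter is derived from the data-dependent quantity.

```r
# Paraphrased sketch of the convert_ratio() mechanism: if the user set the
# ratio parameter, replace it with the absolute parameter computed from n
# (e.g. n = number of features for mtry_ratio).
convert_ratio_sketch = function(pv, target, ratio, n) {
  if (!is.null(pv[[ratio]])) {
    if (!is.null(pv[[target]])) {
      stop(sprintf("Set only one of '%s' and '%s'", target, ratio))
    }
    pv[[target]] = max(1L, as.integer(ceiling(pv[[ratio]] * n)))
    pv[[ratio]] = NULL
  }
  pv
}

# Hypothetical application to aorsf, with n = number of events in the
# training data (an additional cap at n / 2 would still be needed):
convert_ratio_sketch(
  list(leaf_min_events_ratio = 0.25),
  target = "leaf_min_events", ratio = "leaf_min_events_ratio", n = 40
)
# list(leaf_min_events = 10)
```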