mlr-org / mlr3extralearners

Extra learners for use in mlr3.
https://mlr3extralearners.mlr-org.com/

`surv.aorsf`: `mtry_ratio` can result in `mtry = 1`, which causes an error #259

Closed · jemus42 closed this issue 1 year ago

jemus42 commented 1 year ago

Description

mtry_ratio can take valid values that, depending on the task, result in invalid values for mtry.

This is maybe less a bug than undesired behavior; I'm not sure whether it can or should be addressed in the learner, or left as an exercise to the user, so to speak.

Since mtry_ratio effectively has an implicit lower bound that is not 0, this can result in unintended behavior in benchmark scenarios where feature counts vary widely across tasks but mtry_ratio is tuned within [0, 1].

I'm not sure there's anything that could be done on the learner side without causing other issues, though :/

That being said, I guess at least the mtry lower bound should be increased, as it's currently set to 1: https://github.com/mlr-org/mlr3extralearners/blob/5c7780e2c8074a530668ce8194bcf4ae8917099e/R/learner_aorsf_surv_aorsf.R#L40
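For illustration, the conversion quoted in the linked code can be sketched as follows (ratio_to_mtry is a hypothetical helper for demonstration, not an existing function):

```r
# Sketch of the mtry_ratio -> mtry conversion,
# assuming mtry = max(ceiling(mtry_ratio * n_features), 1):
ratio_to_mtry <- function(mtry_ratio, n_features) {
  max(ceiling(mtry_ratio * n_features), 1)
}

ratio_to_mtry(0.1, 9)   # whas task: 9 features -> mtry = 1, which aorsf rejects
ratio_to_mtry(0.1, 30)  # a wider task -> mtry = 3, which is fine
```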

Reproducible example

library(mlr3)
library(mlr3proba)
library(mlr3extralearners)

testlrn = lrn("surv.aorsf", n_tree = 50, control_type = "fast", mtry_ratio = 0.1)

# task "whas" has 9 features, mtry_ratio = 0.1 leads to mtry = 1
testlrn$train(tsk("whas"))
#> Error: mtry = 1 should be >= 2

Created on 2022-12-07 with reprex v2.0.2

Session info (abridged; full `sessioninfo::session_info()` output omitted): R version 4.2.2 (2022-10-31), macOS Ventura 13.0.1, aarch64. Key packages: aorsf 0.0.4 (CRAN), mlr3 0.14.1 (GitHub @6cd6ee3), mlr3extralearners 0.6.0-9000 (GitHub @5c7780e), mlr3proba 0.4.15 (GitHub @e208ec0), paradox 0.10.0.9000 (GitHub @3be7ba6).
sebffischer commented 1 year ago

So I suggest to:

  1. Replace the value 1 with 2 in mtry = max(ceiling(mtry_ratio * n_features), 1)
  2. Increase the lower bound as you suggested

@bcjaeger Do you think that is a good idea?

jemus42 commented 1 year ago

I'm a little worried about tuning though.
Say you do a random search over mtry_ratio in [0, 1]: uniformly sampling that range will then produce a "clump" at mtry = 2 for low values of mtry_ratio (depending on n_features), and I'm not sure whether that's problematic or acceptable.
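To make the clumping concrete, here is a sketch assuming the proposed max(..., 2) floor and a 9-feature task (ratio_to_mtry_floored is a hypothetical helper, not existing code):

```r
# Hypothetical floored conversion from the proposed fix:
ratio_to_mtry_floored <- function(mtry_ratio, n_features) {
  max(ceiling(mtry_ratio * n_features), 2)
}

# Uniform grid over mtry_ratio for a 9-feature task:
ratios <- seq(0.05, 0.95, by = 0.1)
table(sapply(ratios, ratio_to_mtry_floored, n_features = 9))
# every ratio up to 2/9 (~0.22) collapses onto mtry = 2
```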

bcjaeger commented 1 year ago

Hi! Thanks for noticing this,

@bcjaeger Do you think that is a good idea?

I do, both 1. and 2. make sense. I could remove this error from orsf() if that would be more helpful. It's possible to fit oblique random forests with one predictor variable, it's just not really 'oblique'.

Let's say you do a random search over mtry_ratio in [0, 1], uniformly sampling in that range will then lead to a "clump" at mtry = 2 for low values of mtry_ratio (depending on n_features), and I'm not sure whether that's problematic or acceptable.

This is a good point. For orsf learners, could we restrict the mtry_ratio search to cover [alpha, 1], where alpha is the minimally acceptable mtry_ratio (i.e., the value of mtry_ratio that gives mtry = 2)?
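For reference, alpha could be derived from the conversion formula per task; a minimal sketch, where min_mtry_ratio is a hypothetical helper and not part of any existing API:

```r
# Under mtry = ceiling(mtry_ratio * n_features), mtry >= 2 holds
# exactly when mtry_ratio > 1 / n_features, so alpha sits just
# above (min_mtry - 1) / n_features:
min_mtry_ratio <- function(n_features, min_mtry = 2) {
  (min_mtry - 1) / n_features  # open lower bound; add a small epsilon in practice
}

min_mtry_ratio(9)  # whas task: ratios above ~0.111 yield mtry >= 2
```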

jemus42 commented 1 year ago

For orsf learners, could we restrict the mtry_ratio search to cover [alpha, 1], where alpha is the minimally acceptable mtry_ratio?

I don't know if/how we can, since alpha would depend on n_features, which I don't think can be used in a search space / ParamSet for tuning. We can specify e.g. mtry_ratio = p_dbl(0, 1) - which I guess exists in the first place because tuning mtry directly would require knowledge of n_features to find a meaningful upper bound. And in a benchmark where n_features varies, that's a little icky :/

Allowing mtry = 1 would be "easy", I agree, but you are of course right that at this point it would be something like a degenerate case - kind of like running a random forest but setting n_tree = 1 🥴

bcjaeger commented 1 year ago

That definitely makes sense.

I will check out aorsf and see if allowing mtry = 1 causes any unexpected failures in the C++ routines. If it doesn't, I think I can change the error for 1 predictor to be a warning that indicates oblique splitting doesn't really apply if only one predictor is supplied.

If aorsf gives a warning instead of an error in this scenario, would it resolve this issue?

sebffischer commented 1 year ago

I think a warning would be better and would solve the issue.

bcjaeger commented 1 year ago

It looks like allowing mtry = 1 is going to work out for aorsf. I've pushed this change to the main branch in ropensci/aorsf and will submit to CRAN next week.

As far as changes in mlr3 go, I would still be in support of @sebffischer's solution if you would like to avoid seeing a warning message about 1 predictor being used in orsf().

sebffischer commented 1 year ago

Unfortunately I don't really know the algorithm. Is it possible that there are situations where one predictor might be better than two? In that case I would even argue that the warning should be omitted, and that users should know what they are doing when they set parameters. I would argue the warning belongs in the documentation, not the code.

But I think it is your call @bcjaeger

bcjaeger commented 1 year ago

Decision trees can be axis-based or oblique. Axis-based trees use one predictor to partition data into two new branches of the tree, while oblique trees use a linear combination of predictors. A linear combination of one predictor is the same thing as the original predictor as far as a decision tree split is concerned, so an oblique tree with one predictor is the same thing as an axis-based one.
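That equivalence can be checked directly; a standalone sketch (not aorsf code), using an arbitrary positive coefficient and threshold:

```r
# With one predictor, an "oblique" split on beta * x induces the
# same partition as an axis-based split on x (for beta > 0):
set.seed(1)
x    <- rnorm(20)
beta <- 2.5

axis_split    <- x <= 0.4
oblique_split <- (beta * x) <= (beta * 0.4)

identical(axis_split, oblique_split)  # TRUE
```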

I'm okay with putting the warning into the documentation. I would guess that using mtry = 1 while fitting an orsf() model is unlikely to be common outside of tuning routines, so a note in the docs is probably all we need. I have to travel today and this weekend, but I should be able to get this done early next week.

sebffischer commented 1 year ago

Thanks a lot! :)

bcjaeger commented 1 year ago

Hello! Just wanted to let you know orsf() allows mtry = 1 now and does not present a warning:

library(aorsf)
orsf(pbc_orsf, time + status ~ . -id, mtry = 1)
#> ---------- Oblique random survival forest
#> 
#>      Linear combinations: Accelerated
#>           N observations: 276
#>                 N events: 111
#>                  N trees: 500
#>       N predictors total: 17
#>    N predictors per node: 1
#>  Average leaves per tree: 20
#> Min observations in leaf: 5
#>       Min events in leaf: 1
#>           OOB stat value: 0.84
#>            OOB stat type: Harrell's C-statistic
#>      Variable importance: anova
#> 
#> -----------------------------------------

Created on 2022-12-14 with reprex v2.0.2

This change will be in version 0.0.5 of aorsf, which I submitted to CRAN earlier today.