tidymodels / bonsai

parsnip wrappers for tree-based models
https://bonsai.tidymodels.org

aorsf support for `mtry_prop` #87

Open · cgoo4 opened 1 month ago

cgoo4 commented 1 month ago

aorsf is a great addition to bonsai! Any chance of supporting `mtry_prop`?

```r
library(tidymodels)
library(bonsai)

set.seed(1)
folds <- vfold_cv(mtcars, v = 5)

rec <- recipe(cyl ~ ., data = mtcars)

mod_lgbm <- boost_tree(mtry = tune()) |> 
  set_engine("lightgbm", count = FALSE) |>
  set_mode("regression")

mod_aorsf <- rand_forest(mtry = tune()) |> 
  set_engine("aorsf", count = FALSE) |>
  set_mode("regression")

lgbm_wflow <- workflow() |>
  add_model(mod_lgbm) |>
  add_recipe(rec)

aorsf_wflow <- workflow() |>
  add_model(mod_aorsf) |>
  add_recipe(rec)

# lightgbm supports mtry_prop
param_info <-
  lgbm_wflow |>
  extract_parameter_set_dials() |>
  update(mtry = mtry_prop(c(0, 1)))

tune_grid(
  lgbm_wflow, 
  resamples = folds, 
  param_info = param_info,
  metrics = metric_set(rmse)
  )
#> # Tuning results
#> # 5-fold cross-validation 
#> # A tibble: 5 × 4
#>   splits         id    .metrics          .notes          
#>   <list>         <chr> <list>            <list>          
#> 1 <split [25/7]> Fold1 <tibble [10 × 5]> <tibble [0 × 3]>
#> 2 <split [25/7]> Fold2 <tibble [10 × 5]> <tibble [0 × 3]>
#> 3 <split [26/6]> Fold3 <tibble [10 × 5]> <tibble [0 × 3]>
#> 4 <split [26/6]> Fold4 <tibble [10 × 5]> <tibble [0 × 3]>
#> 5 <split [26/6]> Fold5 <tibble [10 × 5]> <tibble [0 × 3]>

# could aorsf do the same?
param_info <-
  aorsf_wflow |>
  extract_parameter_set_dials() |>
  update(mtry = mtry_prop(c(0, 1)))

tune_grid(
  aorsf_wflow, 
  resamples = folds, 
  param_info = param_info,
  metrics = metric_set(rmse)
  )
#> → A | error:   there were unrecognized arguments:
#>                  count is unrecognized - did you mean control?
#> There were issues with some computations   A: x1
#> There were issues with some computations   A: x50
#> 
#> Warning: All models failed. Run `show_notes(.Last.tune.result)` for more
#> information.
#> # Tuning results
#> # 5-fold cross-validation 
#> # A tibble: 5 × 4
#>   splits         id    .metrics .notes           
#>   <list>         <chr> <list>   <list>           
#> 1 <split [25/7]> Fold1 <NULL>   <tibble [10 × 3]>
#> 2 <split [25/7]> Fold2 <NULL>   <tibble [10 × 3]>
#> 3 <split [26/6]> Fold3 <NULL>   <tibble [10 × 3]>
#> 4 <split [26/6]> Fold4 <NULL>   <tibble [10 × 3]>
#> 5 <split [26/6]> Fold5 <NULL>   <tibble [10 × 3]>
#> 
#> There were issues with some computations:
#> 
#>   - Error(s) x50: there were unrecognized arguments:   count is unrecognized - did ...
#> 
#> Run `show_notes(.Last.tune.result)` for more information.
```

Created on 2024-07-21 with reprex v2.1.1

simonpcouch commented 1 month ago

Thanks for the issue!

A few notes from my first pass at wrapping my head around this.

What makes this more straightforward for engines like lightgbm or xgboost is that the tidymodels team maintains a wrapper around the actual training function (for lightgbm, that's `bonsai::train_lightgbm()`), which is easy for us to edit. Since aorsf has a simple interface, we just pass data and arguments directly to `aorsf::orsf()`. That makes things more maintainable on our end, but also less flexible.
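
In the meantime, tuning `mtry` as a count rather than a proportion should already work with the aorsf engine, since the failure above comes only from the unrecognized `count` argument. An untested sketch, reusing `rec` and `folds` from the reprex above (mtcars has 10 predictors once `cyl` is the outcome):

```r
# Workaround sketch: tune mtry over integer counts instead of mtry_prop().
mod_aorsf <- rand_forest(mtry = tune()) |>
  set_engine("aorsf") |>  # no `count = FALSE`, so orsf() never sees it
  set_mode("regression")

aorsf_wflow <- workflow() |>
  add_model(mod_aorsf) |>
  add_recipe(rec)

param_info <-
  aorsf_wflow |>
  extract_parameter_set_dials() |>
  update(mtry = mtry(c(1, 10)))  # counts, bounded by the predictor count

tune_grid(
  aorsf_wflow,
  resamples = folds,
  param_info = param_info,
  metrics = metric_set(rmse)
)
```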

Our options would be:

1) Implement a small wrapper around `aorsf::orsf()` where we just call `bonsai:::process_mtry()` (see the sketch after this list).
2) Wait for parsnip to have more full-fledged support for this, e.g. https://github.com/tidymodels/parsnip/issues/602
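
For option 1, here's a rough sketch of what that wrapper could look like. The name `train_aorsf()` is hypothetical, and the proportion-to-count conversion is written out inline rather than calling the internal `bonsai:::process_mtry()`:

```r
# Hypothetical wrapper, modeled on bonsai::train_lightgbm(): translate a
# proportional mtry into a count before handing off to aorsf::orsf().
train_aorsf <- function(formula, data, mtry = NULL, count = TRUE, ...) {
  if (!is.null(mtry) && !count) {
    # mtry arrived as a proportion in (0, 1]; convert it to a number of
    # predictors (assumes a single outcome on the left of the formula)
    n_predictors <- ncol(data) - 1L
    mtry <- max(1L, as.integer(floor(mtry * n_predictors)))
  }
  aorsf::orsf(data = data, formula = formula, mtry = mtry, ...)
}
```

parsnip's fit module for the aorsf engine would then point at this wrapper rather than at `aorsf::orsf()` directly, the same way the lightgbm engine points at `bonsai::train_lightgbm()`.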