tidymodels / parsnip

A tidy unified interface to models
https://parsnip.tidymodels.org

oblique random forests for classification and regression #1116

Closed: bcjaeger closed this issue 6 months ago

bcjaeger commented 6 months ago

Hello!

aorsf has recently been updated to allow for oblique classification and regression forests. May I submit a PR that would add a classification and regression mode for the aorsf engine?

There are a few datasets where the oblique random forest is really helpful (e.g., modeldata::meats), as the reprex below shows:

suppressPackageStartupMessages({
  library(modeldata)
  library(rsample)
  library(recipes)
  library(workflows)
  library(workflowsets)
  library(yardstick)
})
#> Warning: package 'modeldata' was built under R version 4.3.3
#> Warning: package 'yardstick' was built under R version 4.3.3

# load my branch
devtools::load_all(path = "D:/parsnip/")
#> ℹ Loading parsnip

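# predict protein from the remaining columns; drop the other two
# outcomes (water, fat) so they are not used as predictors
meat_rec <-
  recipe(protein ~ ., data = meats) %>%
  step_select(-water, -fat)

# 10-fold cross-validation (the vfold_cv default)
meat_folds <- vfold_cv(meats)

# oblique forest (aorsf) vs. axis-based forest (ranger) vs. boosted trees
meat_models <- list(oblique = rand_forest(mode = 'regression', 
                                          engine = 'aorsf'),
                    axis = rand_forest(mode = 'regression',
                                       engine = 'ranger'),
                    xgb = boost_tree(mode = 'regression', 
                                     engine = 'xgboost',
                                     trees = 500))

# pair the recipe with each model
workflows <- workflow_set(list(meat_rec), meat_models, cross = TRUE)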

res <- workflows %>% 
  workflow_map("fit_resamples", 
               verbose = TRUE,
               resamples = meat_folds,
               metrics = metric_set(rsq))
#> i 1 of 3 resampling: recipe_oblique
#> ✔ 1 of 3 resampling: recipe_oblique (3.8s)
#> i 2 of 3 resampling: recipe_axis
#> ✔ 2 of 3 resampling: recipe_axis (2.1s)
#> i 3 of 3 resampling: recipe_xgb
#> ✔ 3 of 3 resampling: recipe_xgb (6.2s)

collect_metrics(res)
#> # A tibble: 3 × 9
#>   wflow_id       .config    preproc model .metric .estimator  mean     n std_err
#>   <chr>          <chr>      <chr>   <chr> <chr>   <chr>      <dbl> <int>   <dbl>
#> 1 recipe_oblique Preproces… recipe  rand… rsq     standard   0.944    10 0.00858
#> 2 recipe_axis    Preproces… recipe  rand… rsq     standard   0.529    10 0.0582 
#> 3 recipe_xgb     Preproces… recipe  boos… rsq     standard   0.524    10 0.0574

Created on 2024-05-03 with reprex v2.1.0

simonpcouch commented 6 months ago

Duplicate of https://github.com/tidymodels/bonsai/issues/73. Closing so that we don't track a duplicate issue, but we're certainly interested in making this happen!

Looks like you've got an implementation put together locally? I'd be more than happy to work with you to get this merged into bonsai if you're game to start a PR over there. :)

bcjaeger commented 6 months ago

Thank you! I didn't know about bonsai, but it looks awesome. =] I will open a PR there soon.
