tidymodels / planning

Documents to plan and discuss future development

option for OOB in bagging (e.g., RF) #25

Open hardin47 opened 2 years ago

hardin47 commented 2 years ago

Feature

When running random forests (or other bagged models), OOB model information (predictions, error rates, etc.) should be available.

  1. First of all, I'm not convinced that OOB is a bad option. In this recent paper they say:

In line with results reported in the literature [5], the use of stratified subsampling with sampling fractions that are proportional to response class sizes of the training data yielded almost unbiased error rates in most settings with metric predictors. It therefore presents an easy way of reducing the bias in the OOB error. It does not increase the cost of constructing the RF, since unstratified sampling (bootstrap or subsampling) is simply replaced by stratified subsampling.

This indicates that OOB errors do a good job of estimating error rates (with the added benefit that they require no additional model fitting), as long as stratified subsampling is used instead of the unstratified bootstrap (see the sketch after this list).

  2. Even if nested resampling is superior (and I'll buy that there is an argument to be made), I find that cross-validation and OOB are stepping stones to understanding nested resampling. Do you argue that nested resampling is better than CV? If so, why have CV in the package? Again, OOB happens for free, and sometimes nested resampling isn't even that much better. I think that more people will use nested resampling if they understand OOB, and the path to understanding OOB happens when it is included in the tidymodels package.
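
To make that concrete, here is a rough sketch of passing class-proportional sampling fractions to the ranger engine through set_engine(). I'm using iris only because stratification needs a classification outcome; the exact semantics of a vector-valued sample.fraction are ranger-specific, so see ?ranger::ranger before relying on this.

library(tidymodels)

# class-proportional fractions, scaled to the usual 0.632 subsample size
class_fracs <- 0.632 * prop.table(table(iris$Species))

rf_strat <- rand_forest() %>%
  set_mode("classification") %>%
  set_engine("ranger",
             replace = FALSE,                           # subsample, no bootstrap
             sample.fraction = as.numeric(class_fracs)) # stratified subsampling

strat_fit <- fit(rf_strat, Species ~ ., data = iris)

# OOB misclassification rate under stratified subsampling
strat_fit %>%
  extract_fit_engine() %>%
  pluck("prediction.error")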

Thanks for all that you do!! The tidymodels package is amazing, and I really appreciate all the hard work that has gone into creating it.

EmilHvitfeldt commented 2 years ago

Hello @hardin47! I'm not quite sure I follow what your request is. Could you clarify what you find hard/impossible to do using the tidymodels framework? 😃

The tidymodels packages (parsnip in this instance) don't handle OOB tasks themselves; those calculations are delegated to the engine. As an example, see below: a random forest model is fit, and a combination of extract_fit_engine() and pluck() is used to pull out the OOB predictions and error. This will of course vary from engine to engine.

library(tidymodels)

rf_spec <- rand_forest() %>%
  set_mode("regression") %>%
  set_engine("ranger")

rf_wf <- workflow() %>%
  add_model(rf_spec) %>%
  add_formula(mpg ~ .)

wf_fit <- fit(rf_wf, mtcars)

# OOB predictions
wf_fit %>%
  extract_fit_engine() %>%
  pluck("predictions")
#>  [1] 20.41746 20.48551 26.40105 18.52579 16.62929 19.92354 15.04733 22.70223
#>  [9] 22.28012 18.80575 19.79037 16.26322 15.82583 16.16870 13.64493 12.83027
#> [17] 12.58347 27.85033 29.56797 29.24489 23.31117 16.60676 17.87619 15.68350
#> [25] 16.14184 30.89038 25.48595 25.26826 17.24987 19.80717 15.76371 23.88356

# OOB prediction error (MSE)
wf_fit %>%
  extract_fit_engine() %>%
  pluck("prediction.error")
#> [1] 5.60741

hardin47 commented 2 years ago

Huge apologies for being unclear!!! You are absolutely correct that the OOB pieces can be captured using pluck(). But I use tidymodels (especially in my teaching) to underscore the consistency of modeling and the structure by which we think about the information. I know this sounds silly, but extract_fit_engine() and pluck() are a pretty difficult addition to the tidymodels series of steps that I'm teaching.

What I'd like is for OOB to fit seamlessly into the tidymodels workflow. I want to get OOB errors (or MSE or whatever) as part of the tuning process, tune_grid(). Does that make sense?

Thanks again for everything!

library(tidymodels)

rf_spec <- rand_forest(mtry = tune()) %>%
  set_mode("regression") %>%
  set_engine("ranger")

rf_wf <- workflow() %>%
  add_model(rf_spec) %>%
  add_formula(mpg ~ .)

# wf_fit <- fit(rf_wf, mtcars)

rf_vfold <- vfold_cv(mtcars,
                     v = 3)

mtry_grid <- data.frame(mtry = seq(1, 3, 1))

rf_wf %>%
  tune_grid(resamples = rf_vfold,
            grid = mtry_grid) %>%
  collect_metrics() %>%
  filter(.metric == "rmse")

#> # A tibble: 3 × 7
#>    mtry .metric .estimator  mean     n std_err .config             
#>   <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
#> 1     1 rmse    standard    2.82     3   0.257 Preprocessor1_Model1
#> 2     2 rmse    standard    2.52     3   0.212 Preprocessor1_Model2
#> 3     3 rmse    standard    2.40     3   0.306 Preprocessor1_Model3

Created on 2022-02-21 by the reprex package (v2.0.1)

EmilHvitfeldt commented 2 years ago

No need to apologize! I understand now! Yes, both extract_fit_engine() and pluck() are a little more to introduce on top of everything else. It is still technically possible to extract these values in the tune_grid() workflow using the extract argument of control_grid(); see the reprex below.

I know that doesn't alleviate your problem completely, but it might get you a little closer. We are aware that pulling out the extracted values is not ideal at the moment, but we have plans to remedy that: https://github.com/tidymodels/tune/issues/409.

As for extract_fit_engine() and pluck(): in theory, an extract_oob() generic could be written that would extract the information depending on the engine used. I'm not quite sure where such a function should live right now.

library(tidymodels)

rf_spec <- rand_forest(mtry = tune()) %>%
  set_mode("regression") %>%
  set_engine("ranger")

rf_wf <- workflow() %>%
  add_model(rf_spec) %>%
  add_formula(mpg ~ .)

# wf_fit <- fit(rf_wf, mtcars)

rf_vfold <- vfold_cv(mtcars,
                     v = 3)

mtry_grid <- data.frame(mtry = seq(1, 3, 1))

extract_oob <- function(x) {
  x %>%
    extract_fit_engine() %>%
    pluck("prediction.error")
}

rf_wf %>%
  tune_grid(resamples = rf_vfold,
            grid = mtry_grid, 
            control = control_grid(extract = extract_oob)) %>%
  unnest(.extracts) %>%
  unnest(.extracts)
#> # A tibble: 9 × 7
#>   splits          id    .metrics         .notes    mtry .extracts .config       
#>   <list>          <chr> <list>           <list>   <dbl>     <dbl> <chr>         
#> 1 <split [21/11]> Fold1 <tibble [6 × 5]> <tibble>     1     11.6  Preprocessor1…
#> 2 <split [21/11]> Fold1 <tibble [6 × 5]> <tibble>     2     11.1  Preprocessor1…
#> 3 <split [21/11]> Fold1 <tibble [6 × 5]> <tibble>     3     10.9  Preprocessor1…
#> 4 <split [21/11]> Fold2 <tibble [6 × 5]> <tibble>     1      8.70 Preprocessor1…
#> 5 <split [21/11]> Fold2 <tibble [6 × 5]> <tibble>     2      7.10 Preprocessor1…
#> 6 <split [21/11]> Fold2 <tibble [6 × 5]> <tibble>     3      6.47 Preprocessor1…
#> 7 <split [22/10]> Fold3 <tibble [6 × 5]> <tibble>     1      8.38 Preprocessor1…
#> 8 <split [22/10]> Fold3 <tibble [6 × 5]> <tibble>     2      6.39 Preprocessor1…
#> 9 <split [22/10]> Fold3 <tibble [6 × 5]> <tibble>     3      5.71 Preprocessor1…

Created on 2022-02-21 by the reprex package (v2.0.1)
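
As a very rough sketch (hypothetical, nothing like this exists in the tidymodels packages today), the single-engine extract_oob() helper above could become a generic that dispatches on the class of the engine fit:

library(tidymodels)

extract_oob <- function(x, ...) {
  UseMethod("extract_oob")
}

# workflow and parsnip fits: forward to the underlying engine object
extract_oob.workflow <- function(x, ...) {
  extract_oob(extract_fit_engine(x), ...)
}
extract_oob.model_fit <- function(x, ...) {
  extract_oob(extract_fit_engine(x), ...)
}

# ranger stores a single OOB error (MSE or misclassification rate)
extract_oob.ranger <- function(x, ...) {
  tibble(oob_error = x$prediction.error)
}

# randomForest (classification) stores a per-tree OOB error matrix
extract_oob.randomForest <- function(x, ...) {
  tibble(oob_error = unname(x$err.rate[x$ntree, "OOB"]))
}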

hardin47 commented 2 years ago

I'll have to play around with that to make sure I understand what is happening with extract_oob(). Thanks for this helpful example, and I look forward to all the ways that tidymodels is growing.

topepo commented 2 years ago

This is a good idea and I think that we should try to solve this systematically (and not just for ranger).

Other models have OOB errors, but they come back in a different format (e.g., an OOB confusion table). We might not be able to do something comprehensive across all models.
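
For example, with the randomForest engine the OOB results come back as a confusion matrix with per-class error rates rather than a single number (just an illustration of the format differences, not a proposal for the final format):

library(tidymodels)

rf_cls_fit <- rand_forest() %>%
  set_mode("classification") %>%
  set_engine("randomForest") %>%
  fit(Species ~ ., data = iris)

# OOB confusion matrix stored by the randomForest engine
rf_cls_fit %>%
  extract_fit_engine() %>%
  pluck("confusion")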

I think I have a solution but I won't be able to get to it right away. I've put a moratorium on new packages/features until we have made a lot of progress on case weights.

The idea would be to produce a tibble of specific characteristics of models (for example, OOB statistics for engines that report them). We would have an option to bundle these statistics into the results of the tune functions.

I could add a set of OOB statistics for ranger in the process of doing this.

A side note: you would probably want to avoid any external resampling if you can get OOB errors. In that case, you can use the (poorly named by me) apparent() function to make a resampling object. This specifies that the modeling and holdout sets are the same. This would avoid making multiple versions of the data to estimate performance.
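
A small illustration of what apparent() builds (both "sets" are the full training data, so nothing extra gets copied or held out):

library(rsample)

resamps <- apparent(mtcars)
split <- resamps$splits[[1]]

# analysis and assessment sets are both the complete training data
nrow(analysis(split))
nrow(assessment(split))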

juliasilge commented 2 years ago

@hardin47 Would you be up for creating a pull request to our planning repo outlining the discussion here?

hardin47 commented 2 years ago

Yes, I'd love to! But it won't happen in the next few weeks. Is it something that could wait?

juliasilge commented 2 years ago

Yes @hardin47 for sure