tidymodels / tune

Tools for tidy parameter tuning
https://tune.tidymodels.org

Feature suggestion: Extract splits from tune results as a resampling object #947

Open jrosell opened 6 days ago

jrosell commented 6 days ago

Feature suggestion

Now that we have the new {tailor} package for post-processing in tidymodels, I find myself needing to reuse the splits from tune_results as a resampling object.

I believe a new extract_resamples() function (or whatever name you prefer) could improve the interactive use of tidymodels.

Here is a minimal reproducible example to demonstrate its use:

# pak::pak(
#   paste0(
#     "tidymodels/",
#     c("tune", "workflows", "rsample", "tailor")
#   )
# )
library(tidyverse)
library(tidymodels)
library(probably)
#> 
#> Attaching package: 'probably'
#> The following objects are masked from 'package:base':
#> 
#>     as.factor, as.ordered
library(tailor)
library(stacks)

# How well are our predictions calibrated? Not so well
data(deliveries)
set.seed(1)
delivery_split <- initial_split(deliveries)
delivery_train <- training(delivery_split)
delivery_test  <- testing(delivery_split)
set.seed(1)
delivery_folds <- vfold_cv(delivery_train)
delivery_res <-
  workflow() |>
  add_formula(time_to_delivery ~ .) |>
  add_model(boost_tree(mode = "regression", trees = 3)) |>
  fit_resamples(
    delivery_folds,
    control = control_stack_resamples()
  )
delivery_res |> 
  collect_predictions() |> 
  cal_plot_regression(truth = time_to_delivery, estimate = .pred)

delivery_res |> collect_metrics()
#> # A tibble: 2 × 6
#>   .metric .estimator  mean     n std_err .config             
#>   <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
#> 1 rmse    standard   9.52     10 0.0533  Preprocessor1_Model1
#> 2 rsq     standard   0.853    10 0.00357 Preprocessor1_Model1

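# Rebuild an rset from the splits already stored in the tune results
extract_resamples <- \(x) {
  stopifnot(inherits(x, "tune_results"))
  # Minimal rset from the stored splits; only the `id` column is used,
  # so repeated resampling schemes (with an `id2` column) are not handled
  result_rset <- manual_rset(x$splits, x$id)
  # Keep the new object's names/row.names, restore the original rset attributes
  new_attrs <- attributes(result_rset)[c("names", "row.names")]
  existing_attrs <- attributes(x)$rset_info$att
  att <- modifyList(existing_attrs, new_attrs)
  # Put the original subclass (e.g. "vfold_cv") ahead of the rset/tibble classes
  desired_classes <- c(att$class, "rset", "tbl_df", "tbl", "data.frame")
  att$class <- NULL
  attributes(result_rset) <- att
  class(result_rset) <- desired_classes
  result_rset
}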
waldo::compare(delivery_folds, extract_resamples(delivery_res))
#> ✔ No differences

# Add numeric calibration with {tailor}, reusing the splits stored in the tune results
delivery_res_improved <-
  delivery_res |>
  extract_workflow() |>
  add_tailor(tailor() |> adjust_numeric_calibration()) |>
  fit_resamples(
    extract_resamples(delivery_res),
    control = control_stack_resamples()
  )
delivery_res_improved |> collect_metrics()
#> # A tibble: 2 × 6
#>   .metric .estimator  mean     n std_err .config             
#>   <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
#> 1 rmse    standard   2.71     10 0.0300  Preprocessor1_Model1
#> 2 rsq     standard   0.846    10 0.00432 Preprocessor1_Model1

# Much better: RMSE drops from 9.52 to 2.71
delivery_res_improved |> 
  collect_predictions() |>
  cal_plot_regression(truth = time_to_delivery, estimate = .pred)

sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.3.3 (2024-02-29)
#>  os       Ubuntu 22.04.4 LTS
#>  system   x86_64, linux-gnu
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Europe/Madrid
#>  date     2024-10-09
#>  pandoc   2.9.2.1 @ /bin/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version    date (UTC) lib source
#>  backports      1.4.1      2021-12-13 [1] CRAN (R 4.3.0)
#>  broom        * 1.0.5      2023-06-09 [1] CRAN (R 4.3.1)
#>  butcher        0.3.3      2023-08-23 [1] CRAN (R 4.3.2)
#>  class          7.3-22     2023-05-03 [2] CRAN (R 4.3.3)
#>  cli            3.6.2      2023-12-11 [1] CRAN (R 4.3.2)
#>  codetools      0.2-19     2023-02-01 [2] CRAN (R 4.3.3)
#>  colorspace     2.1-0      2023-01-23 [1] CRAN (R 4.3.0)
#>  data.table     1.15.99    2024-02-20 [1] Github (Rdatatable/data.table@8f8ef93)
#>  dials        * 1.3.0      2024-07-30 [1] RSPM
#>  DiceDesign     1.10       2023-12-07 [1] CRAN (R 4.3.2)
#>  digest         0.6.35     2024-03-11 [1] RSPM (R 4.3.0)
#>  dplyr        * 1.1.4      2023-11-17 [1] CRAN (R 4.3.2)
#>  evaluate       0.23       2023-11-01 [1] CRAN (R 4.3.2)
#>  fansi          1.0.6      2023-12-08 [1] CRAN (R 4.3.2)
#>  farver         2.1.1      2022-07-06 [1] CRAN (R 4.3.0)
#>  fastmap        1.1.1      2023-02-24 [1] CRAN (R 4.3.0)
#>  forcats      * 1.0.0      2023-01-29 [1] CRAN (R 4.3.2)
#>  foreach        1.5.2      2022-02-02 [1] CRAN (R 4.3.0)
#>  fs             1.6.3      2023-07-20 [1] CRAN (R 4.3.1)
#>  furrr          0.3.1      2022-08-15 [1] CRAN (R 4.3.0)
#>  future         1.33.1     2023-12-22 [1] CRAN (R 4.3.2)
#>  future.apply   1.11.1     2023-12-21 [1] CRAN (R 4.3.2)
#>  generics       0.1.3      2022-07-05 [1] CRAN (R 4.3.0)
#>  ggplot2      * 3.5.0      2024-02-23 [1] RSPM (R 4.3.0)
#>  globals        0.16.3     2024-03-08 [1] RSPM (R 4.3.0)
#>  glue           1.7.0      2024-01-09 [1] RSPM (R 4.3.0)
#>  gower          1.0.1      2022-12-22 [1] CRAN (R 4.3.0)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.3.0)
#>  gtable         0.3.4      2023-08-21 [1] CRAN (R 4.3.1)
#>  hardhat        1.4.0      2024-06-02 [1] RSPM
#>  hms            1.1.3      2023-03-21 [1] CRAN (R 4.3.0)
#>  htmltools      0.5.8      2024-03-25 [1] RSPM (R 4.3.0)
#>  infer        * 1.0.7      2024-03-25 [1] RSPM (R 4.3.0)
#>  ipred          0.9-14     2023-03-09 [1] CRAN (R 4.3.0)
#>  iterators      1.0.14     2022-02-05 [1] CRAN (R 4.3.0)
#>  jsonlite       1.8.8      2023-12-04 [1] CRAN (R 4.3.2)
#>  knitr          1.45       2023-10-30 [1] CRAN (R 4.3.2)
#>  labeling       0.4.3      2023-08-29 [1] CRAN (R 4.3.1)
#>  lattice        0.22-5     2023-10-24 [2] CRAN (R 4.3.3)
#>  lava           1.8.0      2024-03-05 [1] RSPM (R 4.3.0)
#>  lhs            1.1.6      2022-12-17 [1] CRAN (R 4.3.0)
#>  lifecycle      1.0.4      2023-11-07 [1] CRAN (R 4.3.2)
#>  listenv        0.9.1      2024-01-29 [1] RSPM (R 4.3.0)
#>  lubridate    * 1.9.3      2023-09-27 [1] CRAN (R 4.3.2)
#>  magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.3.0)
#>  MASS           7.3-60.0.1 2024-01-13 [2] CRAN (R 4.3.3)
#>  Matrix         1.6-5      2024-01-11 [1] RSPM (R 4.3.0)
#>  mgcv           1.9-1      2023-12-21 [2] CRAN (R 4.3.3)
#>  modeldata    * 1.3.0      2024-01-21 [1] RSPM (R 4.3.0)
#>  modelenv       0.1.1      2023-03-08 [1] CRAN (R 4.3.0)
#>  munsell        0.5.0      2018-06-12 [1] CRAN (R 4.3.0)
#>  nlme           3.1-164    2023-11-27 [2] CRAN (R 4.3.3)
#>  nnet           7.3-19     2023-05-03 [2] CRAN (R 4.3.3)
#>  parallelly     1.37.1     2024-02-29 [1] RSPM (R 4.3.0)
#>  parsnip      * 1.2.1.9002 2024-10-08 [1] Github (tidymodels/parsnip@5ce414e)
#>  pillar         1.9.0      2023-03-22 [1] CRAN (R 4.3.0)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.3.0)
#>  probably     * 1.0.3.9001 2024-10-08 [1] Github (tidymodels/probably@545f9ab)
#>  prodlim        2023.08.28 2023-08-28 [1] CRAN (R 4.3.2)
#>  purrr        * 1.0.2      2023-08-10 [1] CRAN (R 4.3.1)
#>  R.cache        0.16.0     2022-07-21 [1] CRAN (R 4.3.1)
#>  R.methodsS3    1.8.2      2022-06-13 [1] CRAN (R 4.3.1)
#>  R.oo           1.26.0     2024-01-24 [1] CRAN (R 4.3.2)
#>  R.utils        2.12.3     2023-11-18 [1] CRAN (R 4.3.2)
#>  R6             2.5.1      2021-08-19 [1] CRAN (R 4.3.0)
#>  Rcpp           1.0.12     2024-01-09 [1] RSPM (R 4.3.0)
#>  readr        * 2.1.5      2024-01-10 [1] RSPM (R 4.3.0)
#>  recipes      * 1.0.10     2024-02-18 [1] RSPM (R 4.3.0)
#>  reprex         2.1.0.9000 2024-01-18 [1] Github (tidyverse/reprex@e1f65e9)
#>  rlang          1.1.3      2024-01-10 [1] RSPM (R 4.3.0)
#>  rmarkdown      2.26       2024-03-05 [1] RSPM (R 4.3.0)
#>  rpart          4.1.23     2023-12-05 [1] RSPM
#>  rsample      * 1.2.1.9000 2024-10-08 [1] Github (tidymodels/rsample@f799dba)
#>  scales       * 1.3.0      2023-11-28 [1] CRAN (R 4.3.2)
#>  sessioninfo    1.2.2      2021-12-06 [1] CRAN (R 4.3.0)
#>  sparsevctrs    0.1.0.9002 2024-10-08 [1] Github (r-lib/sparsevctrs@b29b723)
#>  stacks       * 1.0.4      2024-03-21 [1] RSPM (R 4.3.0)
#>  stringi        1.8.3      2023-12-11 [1] CRAN (R 4.3.2)
#>  stringr      * 1.5.1      2023-11-14 [1] CRAN (R 4.3.2)
#>  styler         1.10.2     2023-08-29 [1] CRAN (R 4.3.2)
#>  survival       3.5-8      2024-02-14 [2] CRAN (R 4.3.3)
#>  tailor       * 0.0.0.9001 2024-10-08 [1] Github (tidymodels/tailor@317a4db)
#>  tibble       * 3.2.1      2023-03-20 [1] CRAN (R 4.3.0)
#>  tidymodels   * 1.2.0      2024-03-25 [1] RSPM (R 4.3.0)
#>  tidyr        * 1.3.1      2024-01-24 [1] CRAN (R 4.3.2)
#>  tidyselect     1.2.1      2024-03-11 [1] RSPM (R 4.3.0)
#>  tidyverse    * 2.0.0.9000 2024-02-20 [1] Github (tidyverse/tidyverse@62f32d4)
#>  timechange     0.3.0      2024-01-18 [1] RSPM (R 4.3.0)
#>  timeDate       4032.109   2023-12-14 [1] CRAN (R 4.3.2)
#>  tune         * 1.2.1.9000 2024-10-08 [1] Github (tidymodels/tune@f8d734a)
#>  tzdb           0.4.0      2023-05-12 [1] CRAN (R 4.3.0)
#>  utf8           1.2.4      2023-10-22 [1] CRAN (R 4.3.2)
#>  vctrs          0.6.5      2023-12-01 [1] CRAN (R 4.3.2)
#>  waldo          0.5.2      2023-11-02 [1] CRAN (R 4.3.2)
#>  withr          3.0.0      2024-01-16 [1] CRAN (R 4.3.2)
#>  workflows    * 1.1.4.9000 2024-10-08 [1] Github (tidymodels/workflows@78aa5df)
#>  workflowsets * 1.1.0      2024-03-21 [1] RSPM (R 4.3.0)
#>  xfun           0.43       2024-03-25 [1] RSPM (R 4.3.0)
#>  xgboost      * 1.7.7.1    2024-01-25 [1] RSPM (R 4.3.0)
#>  yaml           2.3.8      2023-12-11 [1] CRAN (R 4.3.2)
#>  yardstick    * 1.3.1      2024-03-21 [1] RSPM (R 4.3.0)
#> 
#>  [1] /home/jordi/R/x86_64-pc-linux-gnu-library/4.3
#>  [2] /opt/R/4.3.3/lib/R/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────

Created on 2024-10-09 with [reprex v2.1.0.9000](https://reprex.tidyverse.org/)

This implementation gives identical results for my vfold_cv() example, but other rset types should probably be tested too.
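One way to test other rset types is sketched below. It is an untested variation on the helper above, not a proposed tune API: `extract_resamples2()` and `same_splits()` are hypothetical names. It assumes that tune_results keep every id column (`id`, `id2`, ...) and the same `rset_info` attribute the original helper reads, and it swaps manual_rset() for rsample's developer constructor new_rset() so that repeated cross-validation keeps its `id2` column.

```r
library(tidymodels)

# Hypothetical variant of extract_resamples(): rebuilds an rset from a
# tune_results object, keeping every id column (id, id2, ...).
extract_resamples2 <- function(x) {
  stopifnot(inherits(x, "tune_results"))
  # Attributes of the original rset, as stored by tune (assumption)
  att <- attr(x, "rset_info")$att
  subclass <- c(att$class, "rset")
  att$class <- NULL
  # Keep all id columns so repeated schemes are not collapsed to one id
  tbl <- tibble::as_tibble(x)
  ids <- tbl[grepl("^id", names(tbl))]
  rsample::new_rset(
    splits = tbl$splits,
    ids = ids,
    attrib = att,
    subclass = subclass
  )
}

# Two other rset types, resampled with a small model
set.seed(1)
boots <- bootstraps(mtcars, times = 5)
rep_folds <- vfold_cv(mtcars, v = 4, repeats = 2)

boot_res <- fit_resamples(linear_reg(), mpg ~ wt + hp, resamples = boots)
rep_res <- fit_resamples(linear_reg(), mpg ~ wt + hp, resamples = rep_folds)

boot_rebuilt <- extract_resamples2(boot_res)
rep_rebuilt <- extract_resamples2(rep_res)

# Compare the analysis-set row indices of every split
same_splits <- function(a, b) {
  identical(lapply(a$splits, `[[`, "in_id"), lapply(b$splits, `[[`, "in_id"))
}
same_splits(boots, boot_rebuilt)
same_splits(rep_folds, rep_rebuilt)
```

Comparing `in_id` rather than the whole object with waldo::compare() keeps the check focused on the splits themselves, since attribute order may differ between constructors.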

simonpcouch commented 4 days ago

Could you say a little bit more about why it is that you'd need to extract the splits from the tune_results rather than just reusing the splits you have already?

Note to self: FWIW, we did find a use for a similar helper in stacks:::.set_splits().

jrosell commented 4 days ago

Well, in my pipelines I usually have one process for fitting resamples and tuning, and sometimes I only save the tune_results object and not the rset. Then, oops, I need the rset too because I want to check something, and I didn't save it. {tailor} could make this situation more common.

Furthermore, I want to try the AutoGluon inference approach, and this function could help.
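A minimal sketch of the "only the tune_results object was saved" scenario, assuming a simple model on `mtcars` and a temporary `.rds` file as stand-ins for a real pipeline. It uses only rsample's manual_rset(), so the rebuilt object is a bare `manual_rset` rather than a `vfold_cv`; that is enough to refit on the same splits, though not a full replacement for the original rset.

```r
library(tidymodels)

# Resample once and save only the tune_results object (not the rset)
set.seed(1)
folds <- vfold_cv(mtcars, v = 5)
res <- fit_resamples(linear_reg(), mpg ~ wt + hp, resamples = folds)
path <- tempfile(fileext = ".rds")
saveRDS(res, path)

# Later (e.g. in a new session) only the saved results are available;
# the rsplit objects inside them still carry the data and row indices
res_loaded <- readRDS(path)
folds_again <- manual_rset(res_loaded$splits, res_loaded$id)

# Fit a different model on exactly the same splits
res_new <- fit_resamples(linear_reg(), mpg ~ wt, resamples = folds_again)
collect_metrics(res_new)
```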