tidymodels / finetune

Additional functions for model tuning
https://finetune.tidymodels.org/
Other
62 stars 8 forks source link

Feature: collect_race() or collect_metrics() for the race #99

Closed jrosell closed 10 months ago

jrosell commented 10 months ago

Feature

When analyzing the intermediate results of a race, one shouldn't need to know the internal data structure of the tune package; instead, one should be able to use a function such as collect_race() or similar.

Here's an example of what we get with the current functions and what I expect to get.

# Data for this example lives in <project_name>/data under the project root.
project_name <- "sliced-s01e09-playoffs-1"
output_dir <- here::here(project_name, "data")
# Create the directory (and any missing parents); no warning if it exists.
dir.create(file.path(output_dir), showWarnings = FALSE, recursive = TRUE)
# Presumably fetches all files of the Kaggle competition into output_dir;
# requires Kaggle API credentials -- TODO confirm against the kaggler docs.
kaggler::kgl_competitions_data_download_all(project_name, output_dir = output_dir)
library(tidyverse)
library(tidymodels)
library(finetune)
options(readr.show_col_types = FALSE)
theme_set(theme_light())

# Load the training data from the download directory.
train_raw <- read_csv(here::here(output_dir, "train.csv"))

# Recode the outcome as a factor with levels "HR" / "no", drop incomplete
# rows, keep a random sample of 5000 rows, and make a train/test split
# stratified on the outcome. The seed fixes both the sample and the split.
set.seed(123)
bb_split <- train_raw %>%
    mutate(
        is_home_run = if_else(as.logical(is_home_run), "HR", "no"),
        is_home_run = factor(is_home_run)
    ) %>%
    na.omit() %>% 
    sample_n(5000) %>% 
    initial_split(strata = is_home_run)
bb_train <- training(bb_split)
bb_test <- testing(bb_split)
# 10-fold cross-validation on the training set, also stratified on the outcome.
set.seed(234)
bb_folds <- vfold_cv(bb_train, strata = is_home_run, v = 10)

# Preprocessing recipe: predict is_home_run from the listed predictors,
# using the training set as the template.
bb_rec <-
    recipe(is_home_run ~ launch_angle + launch_speed + plate_x + plate_z +
               bb_type + bearing + pitch_mph +
               is_pitcher_lefty + is_batter_lefty +
               inning + balls + strikes + game_date,
           data = bb_train
    ) %>%
    # Replace game_date with a week feature (the original column is dropped).
    step_date(game_date, features = c("week"), keep_original_cols = FALSE) %>%
    # Give missing values in nominal predictors their own "unknown" level,
    # then expand the nominal predictors into dummy variables.
    step_unknown(all_nominal_predictors()) %>%
    step_dummy(all_nominal_predictors())

# Workflow combining the recipe with a boosted-tree classifier in which all
# seven listed hyperparameters are marked for tuning.
xgb_wf <- bb_rec %>% 
    workflow(
        boost_tree(
            mode = "classification",
            trees = tune(),
            min_n = tune(),
            mtry = tune(),
            learn_rate = tune(),
            tree_depth = tune(),
            loss_reduction = tune(),
            sample_size = tune()
        ) %>%
        # counts = FALSE: mtry is presumably treated as a proportion of
        # predictors rather than a count (the grid uses mtry_prop()) --
        # confirm against the parsnip xgboost engine documentation.
        set_engine("xgboost", counts = FALSE)
    )

# Build a 10-candidate max-entropy grid after overriding the default
# parameter ranges. Parameters given an identical lower and upper bound are
# effectively held constant across the grid.
set.seed(123)
xgb_grid <- xgb_wf %>% 
    extract_parameter_set_dials() %>% 
    update(
        # Held constant at 100 trees.
        trees = trees(c(100, 100)),
        min_n = min_n(c(1, 300)),
        mtry = mtry_prop(c(0.1, 0.4)),
        # NOTE(review): learn_rate() ranges are on the log10 scale by default,
        # so c(0.3, 0.3) pins the rate at 10^0.3 (about 2.00, which is what the
        # printed results show), not 0.3. If 0.3 was intended, use
        # learn_rate(c(0.3, 0.3), trans = NULL) or
        # learn_rate(log10(c(0.3, 0.3))).
        learn_rate = learn_rate(c(0.3, 0.3)),
        tree_depth = tree_depth(c(2, 6)),
        # trans = NULL puts the range on the natural scale, so this is
        # held constant at 0.
        loss_reduction = loss_reduction(c(0, 0), trans = NULL),
        sample_size = sample_prop(c(0.4, 0.8))
    ) %>% 
    grid_max_entropy(size = 10)

# Leave 15 cores free and register a parallel backend only when more than
# one core remains (5 on the machine that produced this output).
# NOTE(review): omit = 15 is machine-specific; on machines with few cores
# this presumably falls back to a single core and tuning runs sequentially.
cores <- parallelly::availableCores(omit = 15)
if(cores > 1) {
    print(paste("Using", cores, "cores"))
    doParallel::registerDoParallel(cores)
}
#> [1] "Using 5 cores"
# Race the grid candidates over the folds, scoring with mean log loss only.
# verbose_elim = TRUE prints a message as candidates are eliminated.
set.seed(345)
xgb_rs <- tune_race_anova(
    xgb_wf,
    resamples = bb_folds,
    grid = xgb_grid,
    metrics = metric_set(mn_log_loss),
    control = control_race(verbose_elim = TRUE)
)
#> ℹ Racing will minimize the mn_log_loss metric.
#> ℹ Resamples are analyzed in a random order.
#> ℹ Fold10: 8 eliminated; 2 candidates remain.
#> 
#> ℹ Fold07: All but one parameter combination were eliminated.
# Shut down the parallel backend if one was registered above.
if(cores > 1) {
    doParallel::stopImplicitCluster()
}

# Only one row is returned here: a single candidate survived the race.
xgb_rs %>% show_best(metric = "mn_log_loss")
#> # A tibble: 1 × 13
#>    mtry trees min_n tree_depth learn_rate loss_reduction sample_size .metric    
#>   <dbl> <int> <int>      <int>      <dbl>          <dbl>       <dbl> <chr>      
#> 1 0.277   100    47          5       2.00              0       0.632 mn_log_loss
#> # ℹ 5 more variables: .estimator <chr>, mean <dbl>, n <int>, std_err <dbl>,
#> #   .config <chr>
# # A tibble: 1 × 13
# mtry trees min_n tree_depth learn_rate loss_reduction sample_size .metric     .estimator  mean     n std_err .config              
# <dbl> <int> <int>      <int>      <dbl>          <dbl>       <dbl> <chr>       <chr>      <dbl> <int>   <dbl> <chr>                
# 1 0.277   100    47          5       2.00              0       0.632 mn_log_loss binary     0.107    10 0.00630 Preprocessor1_Model06

# Per-resample metrics, sorted by estimate. By default only the candidate
# that completed every resample is included; eliminated candidates are absent.
xgb_rs %>% collect_metrics(summarize = FALSE) %>% arrange(.estimate)
#> # A tibble: 10 × 12
#>    id      mtry trees min_n tree_depth learn_rate loss_reduction sample_size
#>    <chr>  <dbl> <int> <int>      <int>      <dbl>          <dbl>       <dbl>
#>  1 Fold07 0.277   100    47          5       2.00              0       0.632
#>  2 Fold04 0.277   100    47          5       2.00              0       0.632
#>  3 Fold01 0.277   100    47          5       2.00              0       0.632
#>  4 Fold08 0.277   100    47          5       2.00              0       0.632
#>  5 Fold10 0.277   100    47          5       2.00              0       0.632
#>  6 Fold09 0.277   100    47          5       2.00              0       0.632
#>  7 Fold05 0.277   100    47          5       2.00              0       0.632
#>  8 Fold02 0.277   100    47          5       2.00              0       0.632
#>  9 Fold06 0.277   100    47          5       2.00              0       0.632
#> 10 Fold03 0.277   100    47          5       2.00              0       0.632
#> # ℹ 4 more variables: .metric <chr>, .estimator <chr>, .estimate <dbl>,
#> #   .config <chr>

# Manual workaround (the expected output): summarize the metrics of every
# candidate, including those eliminated part-way through the race, by
# reaching into the .metrics list-column of the tuning result.
xgb_rs %>%
    dplyr::select(id, .order, .metrics) %>%
    tidyr::unnest(cols = .metrics) %>% 
    # Group by every tuning-parameter column; the ids are read from the
    # parameter set stored in the result's attributes.
    dplyr::group_by(!!!rlang::syms(attributes(xgb_rs)$parameters$id), .metric, .estimator) %>%
    dplyr::summarize(
        mean = mean(.estimate, na.rm = TRUE),
        # n counts only the resamples on which the candidate was evaluated.
        n = sum(!is.na(.estimate)),
        std_err = sd(.estimate, na.rm = TRUE) / sqrt(n),
        .groups = "drop"
    ) %>% 
    # Lower log loss is better, so the best candidate comes first.
    arrange(mean) %>% 
    print(n = Inf)
#> # A tibble: 10 × 12
#>     mtry trees min_n tree_depth learn_rate loss_reduction sample_size .metric   
#>    <dbl> <int> <int>      <int>      <dbl>          <dbl>       <dbl> <chr>     
#>  1 0.277   100    47          5       2.00              0       0.632 mn_log_lo…
#>  2 0.310   100    68          4       2.00              0       0.754 mn_log_lo…
#>  3 0.256   100   107          2       2.00              0       0.527 mn_log_lo…
#>  4 0.353   100    63          3       2.00              0       0.540 mn_log_lo…
#>  5 0.343   100    67          5       2.00              0       0.615 mn_log_lo…
#>  6 0.304   100   264          4       2.00              0       0.751 mn_log_lo…
#>  7 0.207   100   158          4       2.00              0       0.418 mn_log_lo…
#>  8 0.115   100   120          4       2.00              0       0.629 mn_log_lo…
#>  9 0.198   100    98          2       2.00              0       0.534 mn_log_lo…
#> 10 0.137   100   209          4       2.00              0       0.725 mn_log_lo…
#> # ℹ 4 more variables: .estimator <chr>, mean <dbl>, n <int>, std_err <dbl>

Created on 2024-01-18 with reprex v2.1.0.9000

simonpcouch commented 10 months ago

Thanks for the issue, @jrosell!

It seems like the all_configs argument to the tune_race collect_metrics() method might be helpful for you!

library(tidymodels)
library(finetune)
library(discrim)
#> 
#> Attaching package: 'discrim'
#> The following object is masked from 'package:dials':
#> 
#>     smoothness

# Example data shipped with the modeldata package.
data(two_class_dat, package = "modeldata")

# Ten bootstrap resamples; the seed makes them reproducible.
set.seed(6376)
rs <- bootstraps(two_class_dat, times = 10)

# optimize a regularized discriminant analysis model; both
# frac_common_cov and frac_identity are marked for tuning
rda_spec <-
  discrim_regularized(frac_common_cov = tune(), frac_identity = tune()) %>%
  set_engine("klaR")

# Race a 10-candidate grid over the bootstraps; verbose_elim = TRUE prints
# a message as candidates are eliminated.
ctrl <- control_race(verbose_elim = TRUE)
set.seed(11)
grid_anova <-
  rda_spec %>%
  tune_race_anova(Class ~ ., resamples = rs, grid = 10, control = ctrl)
#> ℹ Racing will maximize the roc_auc metric.
#> ℹ Resamples are analyzed in a random order.
#> ℹ Bootstrap05: All but one parameter combination were eliminated.

# Plot how each candidate performed over the course of the race.
plot_race(grid_anova)

A quick visual of the racing process:

This is reasonably well-reflected in collect_metrics() output:


# all_configs = TRUE also returns candidates that were eliminated early;
# their n column shows how many resamples they were evaluated on.
collect_metrics(grid_anova, all_configs = TRUE)
#> # A tibble: 20 × 8
#>    frac_common_cov frac_identity .metric  .estimator  mean     n std_err .config
#>              <dbl>         <dbl> <chr>    <chr>      <dbl> <int>   <dbl> <chr>  
#>  1          0.0691        0.0437 accuracy binary     0.811    10 0.00578 Prepro…
#>  2          0.0691        0.0437 roc_auc  binary     0.886    10 0.00513 Prepro…
#>  3          0.199         0.595  accuracy binary     0.733     3 0.0139  Prepro…
#>  4          0.199         0.595  roc_auc  binary     0.825     3 0.00535 Prepro…
#>  5          0.962         0.716  accuracy binary     0.719     3 0.0118  Prepro…
#>  6          0.962         0.716  roc_auc  binary     0.814     3 0.00526 Prepro…
#>  7          0.271         0.910  accuracy binary     0.709     3 0.0130  Prepro…
#>  8          0.271         0.910  roc_auc  binary     0.798     3 0.00501 Prepro…
#>  9          0.781         0.666  accuracy binary     0.726     3 0.0126  Prepro…
#> 10          0.781         0.666  roc_auc  binary     0.818     3 0.00529 Prepro…
#> 11          0.481         0.453  accuracy binary     0.751     3 0.0117  Prepro…
#> 12          0.481         0.453  roc_auc  binary     0.839     3 0.00524 Prepro…
#> 13          0.837         0.824  accuracy binary     0.711     3 0.0142  Prepro…
#> 14          0.837         0.824  roc_auc  binary     0.805     3 0.00519 Prepro…
#> 15          0.605         0.385  accuracy binary     0.754     3 0.0117  Prepro…
#> 16          0.605         0.385  roc_auc  binary     0.846     3 0.00522 Prepro…
#> 17          0.555         0.293  accuracy binary     0.774     3 0.0159  Prepro…
#> 18          0.555         0.293  roc_auc  binary     0.856     3 0.00485 Prepro…
#> 19          0.392         0.154  accuracy binary     0.790     3 0.0133  Prepro…
#> 20          0.392         0.154  roc_auc  binary     0.871     3 0.00459 Prepro…

Created on 2024-01-18 with reprex v2.1.0

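Note the n column in particular: it shows how many resamples each configuration was evaluated on before it was eliminated (3 for most configurations above, 10 for the winner). Setting summarize = FALSE would instead give the performance metrics for each configuration by resample and, if all_configs = TRUE, would include the configurations that weren't fully resampled.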

jrosell commented 10 months ago

My fault. I was checking tune docs instead of finetune docs. https://finetune.tidymodels.org/reference/collect_predictions.html