tidymodels / finetune

Additional functions for model tuning
https://finetune.tidymodels.org/

Improve checking of metric used in `show_best.tune_race()` #89

Closed: hfrick closed this issue 8 months ago

hfrick commented 8 months ago

Following on from https://github.com/tidymodels/extratests/pull/156/files/c147241a882641e12e8c0b89cfdd7aa64817aed4#r1439642019

`show_best.tune_race()` should only error, rather than warn and then error, when the requested metric is not included in the `tune_results` object.

This will also require updating the corresponding tests in extratests.

```r
library(tidymodels)
library(censored)
#> Loading required package: survival
library(finetune)

data("mlc_churn")

mlc_churn <-
  mlc_churn %>%
  mutate(
    churned = ifelse(churn == "yes", 1, 0),
    event_time = survival::Surv(account_length, churned)
  ) %>%
  select(event_time, account_length, area_code, total_eve_calls)

set.seed(6941)
churn_rs <- vfold_cv(mlc_churn)

eval_times <- c(50, 100, 150)

churn_rec <-
  recipe(event_time ~ ., data = mlc_churn) %>%
  step_dummy(area_code) %>%
  step_normalize(all_predictors())

tree_spec <-
  decision_tree(cost_complexity = tune(), min_n = 2) %>%
  set_mode("censored regression")

stc_met <- metric_set(concordance_survival)

set.seed(22)
race_stc_res <- tree_spec %>%
  tune_race_anova(
    event_time ~ .,
    resamples = churn_rs,
    grid = tibble(cost_complexity = 10^c(-1.4, -2.5, -3, -5)),
    metrics = stc_met
  )

show_best(race_stc_res, metric = "brier_survival_integrated")
#> Warning: Metric "concordance_survival" was used to evaluate model candidates in the race
#> but "brier_survival_integrated" has been chosen to rank the candidates. These
#> results may not agree with the race.
#> Error in `show_best()`:
#> ! "brier_survival_integrated" was not in the metric set. Please choose
#>   from: "concordance_survival".
#> Backtrace:
#>     ▆
#>  1. ├─tune::show_best(race_stc_res, metric = "brier_survival_integrated")
#>  2. ├─finetune:::show_best.tune_race(race_stc_res, metric = "brier_survival_integrated")
#>  3. ├─base::NextMethod(...)
#>  4. └─tune:::show_best.tune_results(...)
#>  5.   └─tune::choose_metric(x, metric)
#>  6.     └─tune:::check_metric_in_tune_results(mtr_info, metric, call = call)
#>  7.       └─cli::cli_abort(...)
#>  8.         └─rlang::abort(...)
```

Created on 2024-01-05 with reprex v2.0.2

simonpcouch commented 8 months ago

It'd be nice if we could call tune's `check_metric_in_tune_results()` right before that warning is raised, but `check_metric_in_tune_results()` isn't currently exported. Should we export it, or is it trivial enough to copy?
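A minimal sketch of the reordering under discussion, written in base R so it runs without tune or finetune. `check_metric_available()` and `rank_metric_check()` are hypothetical names that are not part of either package, and the actual fix may simply call an exported `check_metric_in_tune_results()` instead. It also assumes the name of the metric used by the race is passed in directly. The point is only the ordering: validate the requested metric first, so an unknown metric errors immediately, and warn only when a *valid* metric differs from the one the race used.

```r
# Hypothetical helper (not tune's API): error if `metric` is not among the
# metrics stored in the tuning results.
check_metric_available <- function(metric, available) {
  if (!metric %in% available) {
    stop(
      sprintf(
        '"%s" was not in the metric set. Please choose from: %s.',
        metric, paste0('"', available, '"', collapse = ", ")
      ),
      call. = FALSE
    )
  }
  invisible(metric)
}

# Hypothetical sketch of the check order inside show_best.tune_race():
# 1. validate the metric (error only, no warning first);
# 2. warn only if a valid metric differs from the race metric.
rank_metric_check <- function(metric, available, race_metric) {
  check_metric_available(metric, available)
  if (metric != race_metric) {
    warning(
      sprintf(
        paste(
          'Metric "%s" was used to evaluate model candidates in the race',
          'but "%s" has been chosen to rank the candidates.',
          "These results may not agree with the race."
        ),
        race_metric, metric
      ),
      call. = FALSE
    )
  }
  invisible(metric)
}

# The reprex case: an unknown metric now errors without a preceding warning.
try(rank_metric_check(
  "brier_survival_integrated",
  available = "concordance_survival",
  race_metric = "concordance_survival"
))
```

With this order, the warning is reserved for the case it was written for: ranking by a metric that was computed but was not the one driving the race.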

simonpcouch commented 8 months ago

Max gave a thumbs-up on exporting it. I'm on odbc today but will spend some time on this tomorrow if no one else has tackled it by then. :)