tidymodels / tune

Tools for tidy parameter tuning
https://tune.tidymodels.org
Other
271 stars 42 forks source link

`augment()` for tuning result of a survival model fails #756

Closed hfrick closed 9 months ago

hfrick commented 9 months ago

Reprex taken from https://github.com/tidymodels/tune/pull/703, potentially to be made more minimal

library(tidymodels)
library(censored)
#> Loading required package: survival

data("mlc_churn")

mlc_churn <-
  mlc_churn %>%
  mutate(
    churned = ifelse(churn == "yes", 1, 0),
    event_time = Surv(account_length, churned)
  ) %>%
  select(event_time, account_length, voice_mail_plan) %>%
  slice(1:500)

set.seed(6941)
churn_split <- initial_split(mlc_churn, prop = 4/5)
churn_tr <- training(churn_split)
churn_te <- testing(churn_split)
churn_rs <- bootstraps(churn_tr, times = 2)

eval_times <- c(100, 10, 150)
event_metrics <- metric_set(brier_survival, brier_survival_integrated,
                            concordance_survival, roc_auc_survival)
sr_tune_spec <- survival_reg(dist = tune())
poly_rec <- recipe(event_time ~ ., data = churn_tr) %>%
  step_poly(account_length, degree = tune())
sr_tune_wflow <-
  workflow() %>%
  add_model(sr_tune_spec) %>%
  add_recipe(poly_rec)

sr_tune_param <-
  sr_tune_wflow %>%
  extract_parameter_set_dials() %>%
  update(
    degree = degree(c(1, 10)),
    dist = surv_dist(c(c("loglogistic", "lognormal")))
  )

set.seed(1)
sr_tune_res <-
  sr_tune_wflow %>%
  tune_bayes(
    resamples = churn_rs,
    metrics = event_metrics,
    eval_time = eval_times,
    param_info = sr_tune_param,
    initial = 5,
    iter = 2,
    control = control_bayes(save_pred = TRUE)
  )
#> → A | warning: Ran out of iterations and did not converge
#> There were issues with some computations   A: x1
#> There were issues with some computations   A: x3
#> There were issues with some computations   A: x3
#> 

grid_settings <-
  tibble::tribble(
    ~dist,          ~degree, ~.config,
    "loglogistic", 5.86595467322692,  "Iter1"
  )

sr_rs_logn_aug <- augment(sr_tune_res, parameters = grid_settings)
#> Error in `dplyr::summarize()`:
#> ! Can't subset columns that don't exist.
#> ✖ Column `.eval_time` doesn't exist.
#> Backtrace:
#>      ▆
#>   1. ├─generics::augment(sr_tune_res, parameters = grid_settings)
#>   2. ├─tune:::augment.tune_results(sr_tune_res, parameters = grid_settings)
#>   3. │ ├─tune::collect_predictions(x, summarize = TRUE, parameters = parameters)
#>   4. │ └─tune:::collect_predictions.tune_results(...)
#>   5. │   └─tune:::average_predictions(x, parameters)
#>   6. │     └─tune:::surv_summarize(x, param_names, y_nms)
#>   7. │       └─... %>% dplyr::relocate(.pred)
#>   8. ├─dplyr::relocate(., .pred)
#>   9. ├─tidyr::nest(...)
#>  10. ├─dplyr::summarize(...)
#>  11. ├─dplyr:::summarise.data.frame(...)
#>  12. │ └─dplyr:::compute_by(...)
#>  13. │   └─dplyr:::eval_select_by(by, data, error_call = error_call)
#>  14. │     └─tidyselect::eval_select(...)
#>  15. │       └─tidyselect:::eval_select_impl(...)
#>  16. │         ├─tidyselect:::with_subscript_errors(...)
#>  17. │         │ └─rlang::try_fetch(...)
#>  18. │         │   └─base::withCallingHandlers(...)
#>  19. │         └─tidyselect:::vars_select_eval(...)
#>  20. │           └─tidyselect:::walk_data_tree(expr, data_mask, context_mask)
#>  21. │             └─tidyselect:::eval_c(expr, data_mask, context_mask)
#>  22. │               └─tidyselect:::reduce_sels(node, data_mask, context_mask, init = init)
#>  23. │                 └─tidyselect:::walk_data_tree(new, data_mask, context_mask)
#>  24. │                   └─tidyselect:::as_indices_sel_impl(...)
#>  25. │                     └─tidyselect:::as_indices_impl(...)
#>  26. │                       └─tidyselect:::chr_as_locations(x, vars, call = call, arg = arg)
#>  27. │                         └─vctrs::vec_as_location(...)
#>  28. └─vctrs (local) `<fn>`()
#>  29.   └─vctrs:::stop_subscript_oob(...)
#>  30.     └─vctrs:::stop_subscript(...)
#>  31.       └─rlang::abort(...)

Created on 2023-11-14 with reprex v2.0.2

EmilHvitfeldt commented 9 months ago

So I don't think this is a survival issue. That being said, I think we could spin this into some other issues that would be worth working on, as the error messages are not great.

I think the issue is happening because of precision in degree. A degree value of 5.86595467322692 doesn't exist in the tuning results:

collect_predictions(sr_tune_res)$degree |> unique()
#> [1] 5.045940 6.856286 9.842989 1.587631 4.243622 8.608154 5.873989

If you pick a valid combination, then things work. First, for comparison, joining on the original (non-existent) values returns zero rows:

grid_settings <-
  tibble::tribble(
    ~dist,          ~degree, ~.config,
    "loglogistic", 5.86595467322692,  "Iter1"
  )

collect_predictions(sr_tune_res) %>%
  inner_join(grid_settings)
#> Joining with `by = join_by(degree, dist, .config)`
#> # A tibble: 0 × 9
#> # ℹ 9 variables: id <chr>, .pred <list>, .row <int>, degree <dbl>, dist <chr>,
#> #   .pred_time <dbl>, event_time <Surv>, .config <chr>, .iter <int>

grid_settings <- collect_predictions(sr_tune_res) %>%
  select(dist, degree, .config) %>%
  slice(1)

grid_settings
#> # A tibble: 1 × 3
#>   dist        degree .config             
#>   <chr>        <dbl> <chr>               
#> 1 loglogistic   5.05 Preprocessor1_Model1

collect_predictions(sr_tune_res) %>%
  inner_join(grid_settings)
#> Joining with `by = join_by(degree, dist, .config)`
#> # A tibble: 289 × 9
#>    id         .pred     .row degree dist     .pred_time event_time .config .iter
#>    <chr>      <list>   <int>  <dbl> <chr>         <dbl>     <Surv> <chr>   <int>
#>  1 Bootstrap1 <tibble>     2   5.05 loglogi…       96.4        96+ Prepro…     0
#>  2 Bootstrap1 <tibble>     3   5.05 loglogi…       62.7        62+ Prepro…     0
#>  3 Bootstrap1 <tibble>     8   5.05 loglogi…      118.        117+ Prepro…     0
#>  4 Bootstrap1 <tibble>    14   5.05 loglogi…      293.        196+ Prepro…     0
#>  5 Bootstrap1 <tibble>    15   5.05 loglogi…       39.2        39+ Prepro…     0
#>  6 Bootstrap1 <tibble>    17   5.05 loglogi…       72.8        72+ Prepro…     0
#>  7 Bootstrap1 <tibble>    19   5.05 loglogi…       97.4        97+ Prepro…     0
#>  8 Bootstrap1 <tibble>    21   5.05 loglogi…       49.6        49+ Prepro…     0
#>  9 Bootstrap1 <tibble>    29   5.05 loglogi…       91.3        91+ Prepro…     0
#> 10 Bootstrap1 <tibble>    34   5.05 loglogi…       41.5        41+ Prepro…     0
#> # ℹ 279 more rows

result of augment()

augment(sr_tune_res, parameters = grid_settings)
#> Warning: The original data had 400 rows but there were 232 hold-out
#> predictions.
#> # A tibble: 400 × 5
#>    event_time account_length voice_mail_plan .pred            .pred_time
#>        <Surv>          <int> <fct>           <list>                <dbl>
#>  1       149+            149 yes             <NULL>                 NA  
#>  2        96+             96 no              <tibble [3 × 3]>       96.4
#>  3        62+             62 no              <tibble [3 × 3]>       62.7
#>  4        58+             58 yes             <NULL>                 NA  
#>  5        62+             62 no              <NULL>                 NA  
#>  6       114+            114 yes             <NULL>                 NA  
#>  7        72              72 yes             <tibble [3 × 3]>       74.1
#>  8       117+            117 yes             <tibble [3 × 3]>      118. 
#>  9        96+             96 yes             <tibble [3 × 3]>       97.8
#> 10        62              62 no              <tibble [3 × 3]>       63.2
#> # ℹ 390 more rows
hfrick commented 9 months ago

Great sleuthing! I agree that this could be handled better and have opened that other issue: #765

github-actions[bot] commented 9 months ago

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.