tidymodels / tune

Tools for tidy parameter tuning
https://tune.tidymodels.org
Other
275 stars 42 forks source link

workers error with custom metrics when using socket clusters #937

Open simonpcouch opened 3 weeks ago

simonpcouch commented 3 weeks ago

A revival of #479, reprex whittled down from tidymodels/finetune#116; addressing it will also fix tidymodels/yardstick#514. When using socket cluster parallelism (notably, this is not an issue with forking), workers can't find custom-defined yardstick metrics:

library(tidymodels)

# Logic for `event_level`
# Resolve which factor level of `truth` is the "event" class.
# `event_level = "first"` selects the first level; any other value
# (conventionally "second") selects the second level.
event_col <- function(truth, event_level) {
  idx <- if (identical(event_level, "first")) 1L else 2L
  levels(truth)[idx]
}

# Compute the partial AUC over sensitivity in [0.8, 1] via pROC::roc().
#
# @param truth Factor of observed classes (binary).
# @param estimate Numeric vector of class probabilities for the event level.
# @param estimator Only "binary" is supported.
# @param event_level Passed to event_col() to pick the event class.
# @return A single numeric partial-AUC value.
#
# Fix vs. original: a non-"binary" estimator previously fell through the
# `if` and failed with the confusing "object 'pauc_value' not found";
# now it errors explicitly. Also uses `<-` for assignment per R style.
pauc_impl <- function(truth, estimate, estimator = "binary", event_level) {
  if (!identical(estimator, "binary")) {
    stop("`pauc()` only supports the \"binary\" estimator, not \"",
         estimator, "\".", call. = FALSE)
  }

  level_case <- event_col(truth = truth, event_level = event_level)
  level_control <- setdiff(levels(truth), level_case)

  # `direction = "<"` means controls are assumed to have lower estimates.
  result <- pROC::roc(
    estimate,
    response = truth,
    levels = c(level_control, level_case),
    partial.auc = c(0.8, 1),
    partial.auc.focus = "sensitivity",
    direction = "<"
  )

  as.numeric(result$auc)
}

# Vector interface for the partial-AUC metric, following the yardstick
# custom-metric template: finalize the estimator, validate inputs,
# handle missing values, then delegate to the implementation.
pauc_vec <- function(truth,
                     estimate,
                     estimator = NULL,
                     na_rm = TRUE,
                     case_weights = NULL,
                     event_level = "first",
                     ...) {
  # finalize_estimator() calls finalize_estimator_internal() internally
  # to resolve a NULL estimator to this metric class's default.
  estimator <- finalize_estimator(truth, estimator, metric_class = "pauc")

  check_prob_metric(truth, estimate, case_weights, estimator)

  # When missings are kept and any are present, the metric is undefined.
  if (!na_rm && yardstick_any_missing(truth, estimate, case_weights)) {
    return(NA_real_)
  }

  if (na_rm) {
    cleaned <- yardstick_remove_missing(truth, estimate, case_weights)
    truth <- cleaned$truth
    estimate <- cleaned$estimate
    case_weights <- cleaned$case_weights
  }

  pauc_impl(truth, estimate, estimator, event_level)
}

# S3 generic for the partial-AUC metric; dispatches on the class of `data`.
pauc <- function(data, ...) {
  UseMethod("pauc")
}

# Wrap the generic so it carries the yardstick metric-class attributes
# (probability metric, maximized) that metric_set() and tune expect.
# NOTE(review): this rebinding must happen after the generic is defined
# and before the metric is used in a metric_set().
pauc <- new_prob_metric(pauc, direction = "maximize")

# Data-frame method for the partial-AUC metric. Defers column selection
# and grouping/summarizing to yardstick's prob_metric_summarizer(), which
# calls pauc_vec() on the selected columns.
pauc.data.frame <- function(data,
                            truth,
                            estimate,
                            estimator = NULL,
                            na_rm = TRUE,
                            case_weights = NULL,
                            event_level = "first",
                            options = list()) {

  prob_metric_summarizer(
    name = "pauc",
    fn = pauc_vec,
    data = data,
    truth = !!enquo(truth),
    # Estimate columns are spliced in positionally (unnamed), per the
    # prob_metric_summarizer() contract for probability metrics.
    !!enquo(estimate),
    estimator = estimator,
    na_rm = na_rm,
    case_weights = !!enquo(case_weights),
    event_level = event_level,
    fn_options = list(options = options)
  )

}

# Reprex setup: split lending_club into train/test and build 10 CV folds.
set.seed(123)
data("lending_club", package = "modeldata")
split <- initial_split(lending_club)
train <- training(split)
test  <- testing(split)
folds <- vfold_cv(data = train, v = 10)

# Sequential run: the custom `pauc` metric works fine without parallelism.
fit_resamples(
  decision_tree("classification"), 
  Class ~ funded_amnt + term, 
  folds, 
  metrics = metric_set(pauc)
)
#> # Resampling results
#> # 10-fold cross-validation 
#> # A tibble: 10 × 4
#>    splits             id     .metrics         .notes          
#>    <list>             <chr>  <list>           <list>          
#>  1 <split [6652/740]> Fold01 <tibble [1 × 4]> <tibble [0 × 4]>
#>  2 <split [6652/740]> Fold02 <tibble [1 × 4]> <tibble [0 × 4]>
#>  3 <split [6653/739]> Fold03 <tibble [1 × 4]> <tibble [0 × 4]>
#>  4 <split [6653/739]> Fold04 <tibble [1 × 4]> <tibble [0 × 4]>
#>  5 <split [6653/739]> Fold05 <tibble [1 × 4]> <tibble [0 × 4]>
#>  6 <split [6653/739]> Fold06 <tibble [1 × 4]> <tibble [0 × 4]>
#>  7 <split [6653/739]> Fold07 <tibble [1 × 4]> <tibble [0 × 4]>
#>  8 <split [6653/739]> Fold08 <tibble [1 × 4]> <tibble [0 × 4]>
#>  9 <split [6653/739]> Fold09 <tibble [1 × 4]> <tibble [0 × 4]>
#> 10 <split [6653/739]> Fold10 <tibble [1 × 4]> <tibble [0 × 4]>

library(future)
# Switch to socket-cluster (multisession) parallelism; fresh worker R
# sessions do not inherit the globally defined `pauc` functions, which
# triggers the "no applicable method for 'pauc'" failure below.
plan(multisession)

# Identical call to the sequential run above — only the future plan differs.
fit_resamples(
  decision_tree("classification"), 
  Class ~ funded_amnt + term, 
  folds, 
  metrics = metric_set(pauc)
)
#> x Fold01: internal:
#>   Error in `metric_set()`:
#>   ! Failed to compute `pauc()`.
#>   Caused by error in `UseMethod()`:
#>   ! no applicable method for 'pauc' applied to an object of class "c('tb...
#> x Fold02: internal:
#>   Error in `metric_set()`:
#>   ! Failed to compute `pauc()`.
#>   Caused by error in `UseMethod()`:
#>   ! no applicable method for 'pauc' applied to an object of class "c('tb...
#> x Fold03: internal:
#>   Error in `metric_set()`:
#>   ! Failed to compute `pauc()`.
#>   Caused by error in `UseMethod()`:
#>   ! no applicable method for 'pauc' applied to an object of class "c('tb...
#> x Fold04: internal:
#>   Error in `metric_set()`:
#>   ! Failed to compute `pauc()`.
#>   Caused by error in `UseMethod()`:
#>   ! no applicable method for 'pauc' applied to an object of class "c('tb...
#> x Fold05: internal:
#>   Error in `metric_set()`:
#>   ! Failed to compute `pauc()`.
#>   Caused by error in `UseMethod()`:
#>   ! no applicable method for 'pauc' applied to an object of class "c('tb...
#> x Fold06: internal:
#>   Error in `metric_set()`:
#>   ! Failed to compute `pauc()`.
#>   Caused by error in `UseMethod()`:
#>   ! no applicable method for 'pauc' applied to an object of class "c('tb...
#> x Fold07: internal:
#>   Error in `metric_set()`:
#>   ! Failed to compute `pauc()`.
#>   Caused by error in `UseMethod()`:
#>   ! no applicable method for 'pauc' applied to an object of class "c('tb...
#> x Fold08: internal:
#>   Error in `metric_set()`:
#>   ! Failed to compute `pauc()`.
#>   Caused by error in `UseMethod()`:
#>   ! no applicable method for 'pauc' applied to an object of class "c('tb...
#> x Fold09: internal:
#>   Error in `metric_set()`:
#>   ! Failed to compute `pauc()`.
#>   Caused by error in `UseMethod()`:
#>   ! no applicable method for 'pauc' applied to an object of class "c('tb...
#> x Fold10: internal:
#>   Error in `metric_set()`:
#>   ! Failed to compute `pauc()`.
#>   Caused by error in `UseMethod()`:
#>   ! no applicable method for 'pauc' applied to an object of class "c('tb...
#> Warning: All models failed. Run `show_notes(.Last.tune.result)` for more
#> information.
#> # Resampling results
#> # 10-fold cross-validation 
#> # A tibble: 10 × 4
#>    splits             id     .metrics .notes          
#>    <list>             <chr>  <list>   <list>          
#>  1 <split [6652/740]> Fold01 <NULL>   <tibble [1 × 4]>
#>  2 <split [6652/740]> Fold02 <NULL>   <tibble [1 × 4]>
#>  3 <split [6653/739]> Fold03 <NULL>   <tibble [1 × 4]>
#>  4 <split [6653/739]> Fold04 <NULL>   <tibble [1 × 4]>
#>  5 <split [6653/739]> Fold05 <NULL>   <tibble [1 × 4]>
#>  6 <split [6653/739]> Fold06 <NULL>   <tibble [1 × 4]>
#>  7 <split [6653/739]> Fold07 <NULL>   <tibble [1 × 4]>
#>  8 <split [6653/739]> Fold08 <NULL>   <tibble [1 × 4]>
#>  9 <split [6653/739]> Fold09 <NULL>   <tibble [1 × 4]>
#> 10 <split [6653/739]> Fold10 <NULL>   <tibble [1 × 4]>
#> 
#> There were issues with some computations:
#> 
#>   - Error(s) x10: Error in `metric_set()`: ! Failed to compute `pauc()`. Caused by ...
#> 
#> Run `show_notes(.Last.tune.result)` for more information.

Created on 2024-09-05 with reprex v2.1.1

We can fix this by exporting "metrics" as a global, I believe.