tidymodels / yardstick

Tidy methods for measuring model performance
https://yardstick.tidymodels.org/

`brier_class()` does not work when using `frequency_weights()` #438

Closed: zkrog closed this issue 8 months ago

zkrog commented 1 year ago

The problem

I'm not able to use the brier_class() class probability metric in a model workflow whose case weights are created with frequency_weights().

Reproducible example

library(tidymodels)

mtcars_wts <-
  mtcars |> 
  mutate(freq_wt = frequency_weights(cyl),
         am = factor(am))

lr_spec <- 
  logistic_reg()

lr_rec <-
  recipe(am ~ mpg + freq_wt,
         data = mtcars_wts)

lr_wf <- 
  workflow() |>
  add_model(lr_spec) |> 
  add_recipe(lr_rec) |> 
  add_case_weights(freq_wt)

lr_fit <-
  fit(lr_wf, mtcars_wts)

mtcars_wts |> 
  bind_cols(predict(lr_fit, mtcars_wts, type = 'prob')) |> 
  brier_class(truth = am, .pred_1, case_weights = freq_wt)
#> Error in `brier_class()`:
#> ! `vec_math.hardhat_frequency_weights()` not implemented.
#> Backtrace:
#>      ▆
#>   1. ├─yardstick::brier_class(...)
#>   2. └─yardstick:::brier_class.data.frame(...)
#>   3.   └─yardstick::prob_metric_summarizer(...)
#>   4.     ├─rlang::inject(...)
#>   5.     ├─base::withCallingHandlers(...)
#>   6.     └─yardstick (local) fn(...)
#>   7.       └─yardstick:::brier_class_estimator_impl(...)
#>   8.         └─yardstick:::brier_factor(...)
#>   9.           └─yardstick:::brier_ind(inds, estimate, case_weights)
#>  10.             └─vctrs:::Math.vctrs_vctr(case_weights)
#>  11.               ├─vctrs::vec_math(.Generic, x, ...)
#>  12.               └─vctrs:::vec_math.default(.Generic, x, ...)
#>  13.                 └─vctrs:::stop_unimplemented(.x, "vec_math")
#>  14.                   └─vctrs:::stop_vctrs(...)
#>  15.                     └─rlang::abort(message, class = c(class, "vctrs_error"), ..., call = call)

Created on 2023-07-12 with reprex v2.0.2

Session info

sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.1 (2022-06-23 ucrt)
#>  os       Windows 10 x64 (build 19044)
#>  system   x86_64, mingw32
#>  ui       RTerm
#>  language (EN)
#>  collate  English_United States.utf8
#>  ctype    English_United States.utf8
#>  tz       America/New_York
#>  date     2023-07-12
#>  pandoc   2.19.2 @ C:/Program Files/RStudio/bin/quarto/bin/tools/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version    date (UTC) lib source
#>  backports      1.4.1      2021-12-13 [1] CRAN (R 4.2.0)
#>  broom        * 1.0.5      2023-06-09 [1] CRAN (R 4.2.3)
#>  class          7.3-20     2022-01-16 [2] CRAN (R 4.2.1)
#>  cli            3.6.1      2023-03-23 [1] CRAN (R 4.2.3)
#>  codetools      0.2-18     2020-11-04 [2] CRAN (R 4.2.1)
#>  colorspace     2.1-0      2023-01-23 [1] CRAN (R 4.2.2)
#>  data.table     1.14.8     2023-02-17 [1] CRAN (R 4.2.3)
#>  dials        * 1.2.0      2023-04-03 [1] CRAN (R 4.2.3)
#>  DiceDesign     1.9        2021-02-13 [1] CRAN (R 4.2.2)
#>  digest         0.6.33     2023-07-07 [1] CRAN (R 4.2.3)
#>  dplyr        * 1.1.2      2023-04-20 [1] CRAN (R 4.2.3)
#>  ellipsis       0.3.2      2021-04-29 [1] CRAN (R 4.2.2)
#>  evaluate       0.21       2023-05-05 [1] CRAN (R 4.2.3)
#>  fansi          1.0.4      2023-01-22 [1] CRAN (R 4.2.2)
#>  fastmap        1.1.1      2023-02-24 [1] CRAN (R 4.2.3)
#>  foreach        1.5.2      2022-02-02 [1] CRAN (R 4.2.2)
#>  fs             1.6.2      2023-04-25 [1] CRAN (R 4.2.3)
#>  furrr          0.3.1      2022-08-15 [1] CRAN (R 4.2.2)
#>  future         1.33.0     2023-07-01 [1] CRAN (R 4.2.1)
#>  future.apply   1.11.0     2023-05-21 [1] CRAN (R 4.2.3)
#>  generics       0.1.3      2022-07-05 [1] CRAN (R 4.2.2)
#>  ggplot2      * 3.4.2      2023-04-03 [1] CRAN (R 4.2.3)
#>  globals        0.16.2     2022-11-21 [1] CRAN (R 4.2.2)
#>  glue           1.6.2      2022-02-24 [1] CRAN (R 4.2.2)
#>  gower          1.0.1      2022-12-22 [1] CRAN (R 4.2.2)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.2.2)
#>  gtable         0.3.3      2023-03-21 [1] CRAN (R 4.2.3)
#>  hardhat        1.3.0      2023-03-30 [1] CRAN (R 4.2.3)
#>  htmltools      0.5.5      2023-03-23 [1] CRAN (R 4.2.3)
#>  infer        * 1.0.4      2022-12-02 [1] CRAN (R 4.2.2)
#>  ipred          0.9-14     2023-03-09 [1] CRAN (R 4.2.2)
#>  iterators      1.0.14     2022-02-05 [1] CRAN (R 4.2.2)
#>  knitr          1.43       2023-05-25 [1] CRAN (R 4.2.3)
#>  lattice        0.20-45    2021-09-22 [2] CRAN (R 4.2.1)
#>  lava           1.7.2.1    2023-02-27 [1] CRAN (R 4.2.2)
#>  lhs            1.1.6      2022-12-17 [1] CRAN (R 4.2.2)
#>  lifecycle      1.0.3      2022-10-07 [1] CRAN (R 4.2.2)
#>  listenv        0.9.0      2022-12-16 [1] CRAN (R 4.2.2)
#>  lubridate      1.9.2      2023-02-10 [1] CRAN (R 4.2.3)
#>  magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.2.2)
#>  MASS           7.3-57     2022-04-22 [2] CRAN (R 4.2.1)
#>  Matrix         1.6-0      2023-07-08 [1] CRAN (R 4.2.3)
#>  modeldata    * 1.1.0      2023-01-25 [1] CRAN (R 4.2.2)
#>  munsell        0.5.0      2018-06-12 [1] CRAN (R 4.2.2)
#>  nnet           7.3-17     2022-01-16 [2] CRAN (R 4.2.1)
#>  parallelly     1.36.0     2023-05-26 [1] CRAN (R 4.2.3)
#>  parsnip      * 1.1.0      2023-04-12 [1] CRAN (R 4.2.3)
#>  pillar         1.9.0      2023-03-22 [1] CRAN (R 4.2.3)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.2.2)
#>  prodlim        2023.03.31 2023-04-02 [1] CRAN (R 4.2.3)
#>  purrr        * 1.0.1      2023-01-10 [1] CRAN (R 4.2.2)
#>  R.cache        0.16.0     2022-07-21 [1] CRAN (R 4.2.3)
#>  R.methodsS3    1.8.2      2022-06-13 [1] CRAN (R 4.2.2)
#>  R.oo           1.25.0     2022-06-12 [1] CRAN (R 4.2.2)
#>  R.utils        2.12.2     2022-11-11 [1] CRAN (R 4.2.3)
#>  R6             2.5.1      2021-08-19 [1] CRAN (R 4.2.2)
#>  Rcpp           1.0.11     2023-07-06 [1] CRAN (R 4.2.1)
#>  recipes      * 1.0.6      2023-04-25 [1] CRAN (R 4.2.3)
#>  reprex         2.0.2      2022-08-17 [1] CRAN (R 4.2.2)
#>  rlang          1.1.1      2023-04-28 [1] CRAN (R 4.2.3)
#>  rmarkdown      2.23       2023-07-01 [1] CRAN (R 4.2.1)
#>  rpart          4.1.16     2022-01-24 [2] CRAN (R 4.2.1)
#>  rsample      * 1.1.1      2022-12-07 [1] CRAN (R 4.2.2)
#>  rstudioapi     0.15.0     2023-07-07 [1] CRAN (R 4.2.3)
#>  scales       * 1.2.1      2022-08-20 [1] CRAN (R 4.2.2)
#>  sessioninfo    1.2.2      2021-12-06 [1] CRAN (R 4.2.3)
#>  styler         1.10.1     2023-06-05 [1] CRAN (R 4.2.3)
#>  survival       3.3-1      2022-03-03 [2] CRAN (R 4.2.1)
#>  tibble       * 3.2.1      2023-03-20 [1] CRAN (R 4.2.3)
#>  tidymodels   * 1.1.0      2023-05-01 [1] CRAN (R 4.2.3)
#>  tidyr        * 1.3.0      2023-01-24 [1] CRAN (R 4.2.2)
#>  tidyselect     1.2.0      2022-10-10 [1] CRAN (R 4.2.2)
#>  timechange     0.2.0      2023-01-11 [1] CRAN (R 4.2.2)
#>  timeDate       4022.108   2023-01-07 [1] CRAN (R 4.2.2)
#>  tune         * 1.1.1      2023-04-11 [1] CRAN (R 4.2.3)
#>  utf8           1.2.3      2023-01-31 [1] CRAN (R 4.2.3)
#>  vctrs          0.6.3      2023-06-14 [1] CRAN (R 4.2.3)
#>  withr          2.5.0      2022-03-03 [1] CRAN (R 4.2.2)
#>  workflows    * 1.1.3      2023-02-22 [1] CRAN (R 4.2.2)
#>  workflowsets * 1.0.1      2023-04-06 [1] CRAN (R 4.2.3)
#>  xfun           0.39       2023-04-20 [1] CRAN (R 4.2.3)
#>  yaml           2.3.7      2023-01-23 [1] CRAN (R 4.2.2)
#>  yardstick    * 1.2.0      2023-04-21 [1] CRAN (R 4.2.3)
#>
#> ──────────────────────────────────────────────────────────────────────────────

It appears the issue stems from the normalization step in yardstick:::brier_ind(): it calls exp() on the case weights, and vector math (vctrs::vec_math()) is not implemented for hardhat frequency weights:

https://github.com/tidymodels/yardstick/blob/be3ca2352acdaa35ca2b2c9ab7b53750406d844a/R/prob-brier_class.R#L145
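For reference, the weighting arithmetic itself only needs the weights as plain numbers. Below is a minimal base-R sketch of a case-weighted binary Brier score, assuming the common definition (weighted mean of squared differences between a 0/1 event indicator and the predicted event probability). `weighted_brier()` is a hypothetical helper, not yardstick's brier_ind(): it normalizes by dividing by the sum of the weights rather than reproducing the exp() step quoted below, and it unwraps any classed weight vector with as.numeric(), the same conversion used in the temporary fix below.

```r
# Hypothetical helper (not part of yardstick): case-weighted Brier score
# for a binary outcome.
#   truth: 0/1 event indicator
#   prob:  predicted probability of the event
#   wts:   non-negative case weights (plain numeric or a classed vector)
weighted_brier <- function(truth, prob, wts = rep(1, length(truth))) {
  wts <- as.numeric(wts)  # drop any vctrs/hardhat class before doing math
  stopifnot(length(truth) == length(prob), length(prob) == length(wts))
  wts <- wts / sum(wts)   # normalize the weights so they sum to 1
  sum(wts * (truth - prob)^2)  # weighted mean of the squared errors
}

weighted_brier(truth = c(1, 0, 1), prob = c(0.8, 0.3, 0.6), wts = c(1, 2, 1))
#> [1] 0.095
```

With equal weights this reduces to the ordinary (unweighted) mean of the squared errors.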

I can implement a temporary fix with:

trace(yardstick:::brier_ind, edit = TRUE)

and changing

case_weights <- exp(case_weights) / sum(exp(case_weights))

to

case_weights <- exp(as.numeric(case_weights)) / sum(exp(as.numeric(case_weights)))

However, I'm sure that is not the best way to solve the problem, and it is an incomplete fix for my "real world" issue, where I am performing grid tuning in parallel.

For example, after making the above change with trace(), this works:

resamples <- vfold_cv(mtcars_wts)
fits <- fit_resamples(lr_wf,
                      resamples = resamples,
                      metrics = metric_set(brier_class))

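But this does not, presumably because the trace() edit only applies to the main R session, while the parallel workers load an unmodified copy of yardstick: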

doParallel::registerDoParallel(4)

resamples <- vfold_cv(mtcars_wts)
fits <- fit_resamples(lr_wf,
                      resamples = resamples,
                      metrics = metric_set(brier_class))
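For a direct metric call, a narrower workaround that avoids editing yardstick with trace() is to pass the weights to brier_class() as a plain numeric column. The sketch below is self-contained and rests on assumptions: it uses an ordinary glm() fit as a stand-in for the workflow, the names glm_fit, scored, .pred_0 and freq_wt_num are illustrative only, and it assumes that for a binary truth brier_class() expects the probability of the first factor level. It presumably does not help inside fit_resamples(), where the weights come from the workflow's own frequency_weights() column.

```r
library(yardstick)
library(dplyr)

mtcars_wts <- mtcars |>
  mutate(freq_wt = hardhat::frequency_weights(cyl),
         am = factor(am))

# Stand-in model (illustration only, not the workflow fit from the reprex).
glm_fit <- glm(am ~ mpg, data = mtcars, family = binomial)

scored <- mtcars_wts |>
  mutate(
    # glm() returns P(am == 1); assumption: for a binary truth, brier_class()
    # expects the probability of the first factor level ("0").
    .pred_0 = 1 - unname(predict(glm_fit, type = "response")),
    # Workaround: hand the metric plain numeric weights, not hardhat weights.
    freq_wt_num = as.numeric(freq_wt)
  )

brier_class(scored, truth = am, .pred_0, case_weights = freq_wt_num)
```

The conversion mirrors the as.numeric() change made through trace() above, but it is applied to the data rather than to yardstick's internals.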

EmilHvitfeldt commented 1 year ago

Thanks for reporting! Here is a more minimal reprex:

library(yardstick)
library(dplyr)

mtcars_wts <- mtcars |> 
  mutate(freq_wt = hardhat::frequency_weights(cyl),
         am = factor(am))

mtcars_wts |>
  brier_class(am, disp, case_weights = freq_wt)
#> Error in `brier_class()`:
#> ! `vec_math.hardhat_frequency_weights()` not implemented.

EmilHvitfeldt commented 8 months ago

This has been fixed in https://github.com/tidymodels/yardstick/pull/476

github-actions[bot] commented 7 months ago

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.