tidymodels / tune

Tools for tidy parameter tuning
https://tune.tidymodels.org

resample calibration post-processors with an internal split #894

Closed · simonpcouch closed 5 months ago

simonpcouch commented 6 months ago

Related to https://github.com/tidymodels/workflows/pull/225, https://github.com/tidymodels/container/pull/12.

Code looks something like (updated 5/22/2024):

``` r
library(tidymodels)
library(tailor)

y <- seq(0, 7, .001)
dat <- data.frame(y = y, x = y + (y-3)^2)

dat

wflow <- 
  workflow(
    y ~ x, 
    boost_tree("regression", trees = 3),
    tailor("regression") %>% adjust_numeric_calibration("linear")
  )

fit_resamples(wflow, vfold_cv(dat))
```
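
To look at what the internally calibrated workflow produces, here's a minimal sketch of saving and collecting the resampled results. The `res` object and the `control_resamples(save_pred = TRUE)` argument are additions for illustration, not part of the PR:

``` r
# save out-of-sample predictions so the calibrated values can be inspected
res <- fit_resamples(
  wflow,
  vfold_cv(dat),
  control = control_resamples(save_pred = TRUE)
)

# resampled performance estimates for the calibrated workflow
collect_metrics(res)

# averaged held-out predictions, post-calibration
collect_predictions(res, summarize = TRUE)
```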
Previous PR description

This PR proposes resampling calibrators using an "internal split"—it's _very_ scrappy at the moment and intended only for internal testing.

``` r
library(tidymodels)
library(container)
library(probably)
#> 
#> Attaching package: 'probably'
#> The following objects are masked from 'package:base':
#> 
#>     as.factor, as.ordered

# create example data
set.seed(1)
dat <- tibble(y = rnorm(100), x = y/2 + rnorm(100))
dat
#> # A tibble: 100 × 2
#>         y      x
#>     <dbl>  <dbl>
#>  1 -0.626 -0.934
#>  2  0.184  0.134
#>  3 -0.836 -1.33 
#>  4  1.60   0.956
#>  5  0.330 -0.490
#>  6 -0.820  1.36 
#>  7  0.487  0.960
#>  8  0.738  1.28 
#>  9  0.576  0.672
#> 10 -0.305  1.53 
#> # ℹ 90 more rows

dat_boots <- bootstraps(dat)

# construct workflow
wf_simple <- workflow(y ~ x, boost_tree("regression", trees = 3))

# specify calibration
reg_ctr <- container(mode = "regression") %>% adjust_numeric_calibration(type = "linear")
wf_post <- wf_simple %>% add_container(reg_ctr)

# resample workflows
set.seed(1)
wf_simple_res <- fit_resamples(
  wf_simple,
  dat_boots,
  control = control_grid(save_pred = TRUE)
)

set.seed(1)
wf_post_res <- fit_resamples(
  wf_post,
  dat_boots,
  control = control_grid(save_pred = TRUE)
)

# ...train the post-processor post-hoc
cal_manual <- cal_estimate_linear(wf_simple_res, truth = y)
cal_manual_preds <- cal_apply(wf_simple_res, cal_manual)

simple_preds <- collect_predictions(wf_simple_res, summarize = TRUE)
cal_auto_preds <- collect_predictions(wf_post_res, summarize = TRUE)

cal_manual_preds
#> # A tibble: 100 × 4
#>      .pred  .row      y .config             
#>      <dbl> <int>  <dbl> <chr>               
#>  1 -0.167      1 -0.626 Preprocessor1_Model1
#>  2  0.267      2  0.184 Preprocessor1_Model1
#>  3  0.215      3 -0.836 Preprocessor1_Model1
#>  4  0.273      4  1.60  Preprocessor1_Model1
#>  5 -0.118      5  0.330 Preprocessor1_Model1
#>  6  0.269      6 -0.820 Preprocessor1_Model1
#>  7  0.140      7  0.487 Preprocessor1_Model1
#>  8  0.219      8  0.738 Preprocessor1_Model1
#>  9  0.254      9  0.576 Preprocessor1_Model1
#> 10  0.0856    10 -0.305 Preprocessor1_Model1
#> # ℹ 90 more rows
```

Averaged predictions from the uncalibrated model:

``` r
ggplot(simple_preds, aes(x = y, y = .pred)) +
  geom_point()
```

![](https://i.imgur.com/lPAu5VG.png)

Averaged predictions from the model calibrated internally in tune:

``` r
ggplot(cal_auto_preds, aes(x = y, y = .pred)) +
  geom_point()
```

![](https://i.imgur.com/GRx0vvY.png)

Averaged predictions from the uncalibrated model, calibrated manually after the fact with probably (I'm not sure I got the flow right with `cal_estimate_linear(...) %>% cal_apply(...)`?):

``` r
ggplot(cal_manual_preds, aes(x = y, y = .pred)) +
  geom_point()
```

![](https://i.imgur.com/tb9HvRJ.png)

Created on 2024-04-26 with [reprex v2.1.0](https://reprex.tidyverse.org)

As-is, this PR doesn't apply any postprocessor if there's not a calibrator in the postprocessor--mostly intended to allow for experimentation on the statistical properties of resampling calibrators in this way.
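
On the manual flow the reprex above is unsure about: one common pattern with probably is to estimate the calibration from out-of-sample predictions and then apply it to predictions from a separate fit. A minimal sketch, assuming the objects from the reprex above exist; `final_fit` and `new_preds` are hypothetical names introduced here, not part of the PR:

``` r
# estimate a linear calibration from the resampled (out-of-sample) predictions
cal_manual <- cal_estimate_linear(wf_simple_res, truth = y)

# fit the uncalibrated workflow once on the full data set (hypothetical step)
final_fit <- fit(wf_simple, dat)

# generate predictions, then adjust them with the estimated calibration
new_preds <- augment(final_fit, new_data = dat)
cal_apply(new_preds, cal_manual)
```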
topepo commented 6 months ago

I was thinking a lot about this this morning. Some thoughts not in our google doc:

simonpcouch commented 5 months ago

With an eye toward reducing `Remotes` hoopla, I'm going to go ahead and merge and open issues for the smaller todos.

github-actions[bot] commented 5 months ago

This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.