tidymodels / workflows

Modeling Workflows
https://workflows.tidymodels.org/

first pass at postprocessing proof-of-concept #225

Closed: simonpcouch closed this 4 months ago

simonpcouch commented 5 months ago

Mostly just pattern-matches recipe and model_fit implementations and hooks into the existing post-processing infrastructure. See tests for up-to-date examples.
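
For quick orientation, here is a minimal sketch of the interface this PR adds, pieced together from the examples later in this thread (both the `workflow()` argument and `add_container()` forms appear there); it only constructs workflows and doesn't fit anything:

``` r
library(workflows)
library(parsnip)
library(container)
library(dplyr)

# a postprocessor that lowers the classification threshold to 0.1
post <- container(mode = "classification", type = "binary") %>%
  adjust_probability_threshold(.1)

# the postprocessor can be passed directly to workflow()...
wf <- workflow(Status ~ ., logistic_reg(), post)

# ...or added to an existing workflow
wf <- workflow(Status ~ ., logistic_reg()) %>% add_container(post)
```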

Previous PR description (outdated):

A fast and loose proof-of-concept for integrating a postprocessing container.

``` r
library(workflows)
library(parsnip)
library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
library(container)
library(modeldata)

table(credit_data$Status)
#> 
#>  bad good 
#> 1254 3200

wflow_class <- fit(workflow(Status ~ ., logistic_reg()), credit_data)

predict(wflow_class, credit_data) %>% table()
#> .pred_class
#>  bad good 
#>  681 3358

post <-
  container(mode = "classification", type = "binary") %>%
  adjust_probability_threshold(.1)

wflow_class_container <- workflow(Status ~ ., logistic_reg(), post)
wflow_class_container <- fit(wflow_class_container, credit_data)

predict(wflow_class_container, credit_data) %>% table()
#> .pred_class
#>  bad good 
#> 2659 1380
```

Created on 2024-04-23 with [reprex v2.1.0](https://reprex.tidyverse.org)
simonpcouch commented 5 months ago

On R 3.6.3, seeing:

  Error: 
  ! error in pak subprocess
  Caused by error: 
  ! Could not solve package dependencies:
  * deps::.: Can't install dependency tidymodels/container#1
  * tidymodels/container#1: Can't install dependency tidymodels/probably (>= 1.0.3.9000)
  * tidymodels/probably: Can't install dependency tune (>= 1.1.2)
  * tune: Needs R >= 4.0

Related tune line.

simonpcouch commented 5 months ago

I believe the dependency conflicts in the GHA installs right now are due to the fact that tidymodels/recipes and tidymodels/parsnip have a tidymodels/hardhat (i.e. main) Remotes ref while this PR has the upstream tidymodels/hardhat#248 ref. I'm not sure there's a workaround here, though we could just review and merge the hardhat PR tidymodels/hardhat#248, as its implementation doesn't depend on how postprocessors are implemented.
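
For reference, the mismatch is between Remotes entries along these lines (illustrative, not the exact DESCRIPTION contents):

```
# recipes, parsnip (hardhat default branch):
Remotes: tidymodels/hardhat

# this PR (hardhat PR branch):
Remotes: tidymodels/hardhat#248
```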

simonpcouch commented 5 months ago

Turns out this PR doesn't need further changes to support tidymodels/container#12. With that PR and this PR as-is:

library(tidymodels)
library(container)
library(probably)
#> 
#> Attaching package: 'probably'
#> The following objects are masked from 'package:base':
#> 
#>     as.factor, as.ordered

# create example data
set.seed(1)
dat <- tibble(y = rnorm(100), x = y/2 + rnorm(100))

dat
#> # A tibble: 100 × 2
#>         y      x
#>     <dbl>  <dbl>
#>  1 -0.626 -0.934
#>  2  0.184  0.134
#>  3 -0.836 -1.33 
#>  4  1.60   0.956
#>  5  0.330 -0.490
#>  6 -0.820  1.36 
#>  7  0.487  0.960
#>  8  0.738  1.28 
#>  9  0.576  0.672
#> 10 -0.305  1.53 
#> # ℹ 90 more rows

# construct workflow
wf_simple <- workflow(y ~ x, boost_tree("regression", trees = 3))

# specify calibration
reg_ctr <-
  container(mode = "regression") %>%
  adjust_numeric_calibration(type = "linear")

wf_post <- wf_simple %>% add_container(reg_ctr)

# train workflow
wf_simple_fit <- fit(wf_simple, dat)
wf_post_fit <- fit(wf_post, dat)

# predict from both workflows
predict(wf_simple_fit, dat)
#> # A tibble: 100 × 1
#>      .pred
#>      <dbl>
#>  1 -0.223 
#>  2  0.638 
#>  3 -0.218 
#>  4  0.609 
#>  5  0.0739
#>  6  0.119 
#>  7  0.609 
#>  8  0.568 
#>  9  0.378 
#> 10  0.262 
#> # ℹ 90 more rows
predict(wf_post_fit, dat)
#> # A tibble: 100 × 1
#>      .pred
#>      <dbl>
#>  1 -0.719 
#>  2  0.904 
#>  3 -0.701 
#>  4  0.856 
#>  5 -0.258 
#>  6 -0.243 
#>  7  0.856 
#>  8  0.784 
#>  9  0.263 
#> 10 -0.0765
#> # ℹ 90 more rows

# ...verify output is the same as training the post-proc separately
# (fyi we're naughtily re-predicting the training set)
wf_simple_preds <- augment(wf_simple_fit, dat)
calibrator <- cal_estimate_linear(wf_simple_preds, truth = y)
cal_apply(wf_simple_preds, calibrator)
#> # A tibble: 100 × 3
#>      .pred      y      x
#>      <dbl>  <dbl>  <dbl>
#>  1 -0.719  -0.626 -0.934
#>  2  0.904   0.184  0.134
#>  3 -0.701  -0.836 -1.33 
#>  4  0.856   1.60   0.956
#>  5 -0.258   0.330 -0.490
#>  6 -0.243  -0.820  1.36 
#>  7  0.856   0.487  0.960
#>  8  0.784   0.738  1.28 
#>  9  0.263   0.576  0.672
#> 10 -0.0765 -0.305  1.53 
#> # ℹ 90 more rows

Created on 2024-04-26 with reprex v2.1.0

simonpcouch commented 5 months ago

ec0effa surfaces an important point: removing or updating a postprocessor from an otherwise trained workflow need not discard the preprocessor and model fits, since they aren't affected by the change to the postprocessor. This introduces the possibility of a "partially trained" workflow, where a workflow with a trained preprocessor and model but an untrained postprocessor should be able to fit without issue.
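
For illustration, a minimal sketch of that situation, assuming the container stage follows the same add_*/update_*/remove_* conventions as the other workflow stages (update_container() here is hypothetical, not something shown in this thread):

``` r
library(tidymodels)
library(container)

set.seed(1)
dat <- tibble(y = rnorm(100), x = y / 2 + rnorm(100))

# a fully trained workflow: preprocessor, model, and postprocessor
wf_fit <-
  workflow(y ~ x, boost_tree("regression", trees = 3)) %>%
  add_container(
    container(mode = "regression") %>%
      adjust_numeric_calibration(type = "linear")
  ) %>%
  fit(dat)

# swapping the postprocessor shouldn't discard the trained preprocessor
# and model fit; only the postprocessor needs re-training, leaving a
# "partially trained" workflow that can still be fit without issue
wf_partial <-
  wf_fit %>%
  update_container(  # hypothetical; mirrors update_recipe()/update_model()
    container(mode = "regression") %>%
      adjust_numeric_calibration(type = "linear")
  )
```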

hfrick commented 5 months ago

After chatting with Max:

simonpcouch commented 4 months ago

With an eye toward reducing Remotes hoopla, I'm going to go ahead and merge and open issues for the smaller todos.

github-actions[bot] commented 3 months ago

This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.