tidymodels / workflowsets

Create a collection of modeling workflows
https://workflowsets.tidymodels.org/

Predict on new_data from a workflow_set #111

Closed: gsimchoni closed this issue 1 year ago

gsimchoni commented 1 year ago

Feature

Once the workflows in a workflow_set have been fitted, e.g. by workflow_map(), it would be nice to be able to predict() from these workflows on new_data. Currently I'm not sure how to do this; tell me if I'm missing an obvious solution.

I would suggest using lm_models %>% workflow_map("predict", new_data = ames_test), see below.

Reprex

library(tidyverse)
library(tidymodels)

# get train, test, folds object
ames_split <- initial_split(ames, prop = 0.80)
ames_train <- training(ames_split)
ames_test  <-  testing(ames_split)
ames_folds <- vfold_cv(ames_train, v = 10)

basic_rec <- 
  recipe(Sale_Price ~ Neighborhood + Gr_Liv_Area + Year_Built + Bldg_Type + 
           Latitude + Longitude, data = ames_train) %>%
  step_log(Gr_Liv_Area, base = 10) %>% 
  step_other(Neighborhood, threshold = 0.01) %>% 
  step_dummy(all_nominal_predictors())

interaction_rec <- 
  basic_rec %>% 
  step_interact( ~ Gr_Liv_Area:starts_with("Bldg_Type_") ) 

spline_rec <- 
  interaction_rec %>% 
  step_ns(Latitude, Longitude, deg_free = 50)

preproc <- 
  list(basic = basic_rec, 
       interact = interaction_rec, 
       splines = spline_rec
  )

lm_models <- workflow_set(preproc, list(lm = linear_reg()), cross = FALSE)

lm_models <-
  lm_models %>% 
  workflow_map("fit_resamples", resamples = ames_folds)

Now I would like to use either:


# predict(lm_models, new_data = ames_test)
# or
# lm_models %>% 
#    workflow_map("predict", new_data = ames_test)

(though I can see why we wouldn't want a predict.workflow_set() function)

simonpcouch commented 1 year ago

Thanks for the issue!

Could I ask what you'd hope to do with that functionality?

If you’re looking to evaluate performance of each model with the goal of ultimately choosing one, you may want to check out some of the exploratory methods in the package which are based on assessment set predictions:

autoplot(lm_models)


collect_metrics(lm_models)
#> # A tibble: 6 × 9
#>   wflow_id    .config        preproc model .metric .esti…¹    mean     n std_err
#>   <chr>       <chr>          <chr>   <chr> <chr>   <chr>     <dbl> <int>   <dbl>
#> 1 basic_lm    Preprocessor1… recipe  line… rmse    standa… 3.90e+4    10 1.11e+3
#> 2 basic_lm    Preprocessor1… recipe  line… rsq     standa… 7.61e-1    10 8.03e-3
#> 3 interact_lm Preprocessor1… recipe  line… rmse    standa… 3.87e+4    10 1.16e+3
#> 4 interact_lm Preprocessor1… recipe  line… rsq     standa… 7.65e-1    10 8.62e-3
#> 5 splines_lm  Preprocessor1… recipe  line… rmse    standa… 3.63e+4    10 1.23e+3
#> 6 splines_lm  Preprocessor1… recipe  line… rsq     standa… 7.93e-1    10 1.13e-2
#> # … with abbreviated variable name ¹​.estimator

collect_predictions(lm_models)
#> # A tibble: 7,032 × 7
#>    wflow_id .config              preproc model       .row Sale_Price   .pred
#>    <chr>    <chr>                <chr>   <chr>      <int>      <int>   <dbl>
#>  1 basic_lm Preprocessor1_Model1 recipe  linear_reg     1     189000 205612.
#>  2 basic_lm Preprocessor1_Model1 recipe  linear_reg     2     157900 146219.
#>  3 basic_lm Preprocessor1_Model1 recipe  linear_reg     3     174500 199241.
#>  4 basic_lm Preprocessor1_Model1 recipe  linear_reg     4     119000 131733.
#>  5 basic_lm Preprocessor1_Model1 recipe  linear_reg     5     137500 196413.
#>  6 basic_lm Preprocessor1_Model1 recipe  linear_reg     6     146000 154397.
#>  7 basic_lm Preprocessor1_Model1 recipe  linear_reg     7     114000 115674.
#>  8 basic_lm Preprocessor1_Model1 recipe  linear_reg     8     212000 218955.
#>  9 basic_lm Preprocessor1_Model1 recipe  linear_reg     9     311500 222989.
#> 10 basic_lm Preprocessor1_Model1 recipe  linear_reg    10      84900  54651.
#> # … with 7,022 more rows

rank_results(lm_models)
#> # A tibble: 6 × 9
#>   wflow_id    .config          .metric    mean std_err     n prepr…¹ model  rank
#>   <chr>       <chr>            <chr>     <dbl>   <dbl> <int> <chr>   <chr> <int>
#> 1 splines_lm  Preprocessor1_M… rmse    3.63e+4 1.23e+3    10 recipe  line…     1
#> 2 splines_lm  Preprocessor1_M… rsq     7.93e-1 1.13e-2    10 recipe  line…     1
#> 3 interact_lm Preprocessor1_M… rmse    3.87e+4 1.16e+3    10 recipe  line…     2
#> 4 interact_lm Preprocessor1_M… rsq     7.65e-1 8.62e-3    10 recipe  line…     2
#> 5 basic_lm    Preprocessor1_M… rmse    3.90e+4 1.11e+3    10 recipe  line…     3
#> 6 basic_lm    Preprocessor1_M… rsq     7.61e-1 8.03e-3    10 recipe  line…     3
#> # … with abbreviated variable name ¹​preprocessor

These methods supply the needed information to select the best model. Making use of the test set to evaluate each fit and select a final one to train on the full training set means that you would have no data left to evaluate performance of that final model with—if you were to use the test set again, this would lead to data leakage.

The above issue is why we haven't opened up that interface; we'd be hesitant to make that workflow feel too easy/natural.
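For anyone reproducing the output above: as far as I know, collect_predictions() only works if the assessment set predictions were saved during resampling, which the reprex as posted does not request. Below is a minimal sketch, assuming a cut-down version of the reprex. The two smaller recipes, the 5 folds, the seed, and the names keep_pred and assess_preds are my own choices for illustration, not something from this thread.

```r
library(tidymodels)

set.seed(111)  # arbitrary seed, only so the sketch is reproducible
ames_split <- initial_split(ames, prop = 0.80)
ames_train <- training(ames_split)
ames_folds <- vfold_cv(ames_train, v = 5)

# Cut-down recipes (assumption: fewer predictors than the original reprex)
basic_rec <-
  recipe(Sale_Price ~ Gr_Liv_Area + Year_Built + Bldg_Type, data = ames_train) %>%
  step_log(Gr_Liv_Area, base = 10) %>%
  step_dummy(all_nominal_predictors())

interaction_rec <-
  basic_rec %>%
  step_interact(~ Gr_Liv_Area:starts_with("Bldg_Type_"))

# save_pred = TRUE keeps the held-out (assessment set) predictions for each fold
keep_pred <- control_resamples(save_pred = TRUE, save_workflow = TRUE)

lm_models <-
  workflow_set(
    preproc = list(basic = basic_rec, interact = interaction_rec),
    models = list(lm = linear_reg()),
    cross = FALSE
  ) %>%
  workflow_map("fit_resamples", resamples = ames_folds, control = keep_pred)

# One row per training observation per workflow; the test set is never touched
assess_preds <- collect_predictions(lm_models)
```

Every prediction here comes from a model that did not see that row during fitting, which is why these estimates can be used to compare workflows without spending the test set.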

gsimchoni commented 1 year ago

Thank you for this.

Making use of the test set to evaluate each fit and select a final one to train on the full training set means that you would have no data left to evaluate performance of that final model with—if you were to use the test set again, this would lead to data leakage.

I see what you mean. In that case, would it not make sense to be able to choose a single workflow from a workflow_set, finalize it, and use it for prediction on final, untouched test data? How would one go about it?

simonpcouch commented 1 year ago

That workflow you just proposed sounds right, as long as the choice of a single workflow is based on the estimates from the assessment sets. In that case, the test data isn't used at all in the model's development; right at the end, we just use it to get a solid estimate of our performance. :)
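One way this could look in code is sketched below, assuming a cut-down version of the reprex (two smaller recipes, 5 folds, an arbitrary seed). The names best_id, final_res, and test_preds are hypothetical and only for illustration. The sketch picks the workflow with the lowest resampled RMSE, then uses last_fit() to refit it on the full training set and evaluate it once on the test set.

```r
library(tidymodels)

set.seed(111)  # arbitrary seed, only so the sketch is reproducible
ames_split <- initial_split(ames, prop = 0.80)
ames_train <- training(ames_split)
ames_test  <- testing(ames_split)
ames_folds <- vfold_cv(ames_train, v = 5)

# Cut-down recipes (assumption: fewer predictors than the original reprex)
basic_rec <-
  recipe(Sale_Price ~ Gr_Liv_Area + Year_Built + Bldg_Type, data = ames_train) %>%
  step_log(Gr_Liv_Area, base = 10) %>%
  step_dummy(all_nominal_predictors())

interaction_rec <-
  basic_rec %>%
  step_interact(~ Gr_Liv_Area:starts_with("Bldg_Type_"))

lm_models <-
  workflow_set(
    preproc = list(basic = basic_rec, interact = interaction_rec),
    models = list(lm = linear_reg()),
    cross = FALSE
  ) %>%
  workflow_map("fit_resamples", resamples = ames_folds)

# Choose a workflow using resampling (assessment set) estimates only
best_id <-
  rank_results(lm_models, rank_metric = "rmse") %>%
  filter(.metric == "rmse") %>%
  slice_min(rank, n = 1, with_ties = FALSE) %>%
  pull(wflow_id)

# Refit the chosen workflow on the full training set and
# evaluate it a single time on the held-out test set
final_res <-
  lm_models %>%
  extract_workflow(id = best_id) %>%
  last_fit(split = ames_split)

collect_metrics(final_res)                   # test set rmse and rsq
test_preds <- collect_predictions(final_res) # test set predictions
```

Nothing is tuned in this example, so there is no finalization step. With tuned workflows, I believe you would first pull the chosen result with extract_workflow_set_result(), pick parameters with select_best(), and pass them to finalize_workflow() before calling last_fit().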

gsimchoni commented 1 year ago

Exactly. As above, ames_test isn't used during modeling; it is only used once, at the end.

simonpcouch commented 1 year ago

A procedure a la

lm_models %>% 
    workflow_map("predict", new_data = ames_test)

would access ames_test as many times as there are models in lm_models.
