tidymodels / workflowsets

Create a collection of modeling workflows
https://workflowsets.tidymodels.org/

FR: Combine multiple workflowsets #93

Closed exsell-jc closed 2 years ago

exsell-jc commented 2 years ago

To cite one example, the modeltime library requires a date column, while other algorithms cannot handle one. It would be really nice if I could just combine these two workflowsets, with something along the lines of wfs_total = wfs_models |> cbind(wfs_modeltime) or wfs_total = wfs_models + wfs_modeltime or wfs_total = wfs_models |> combine_workflowsets(wfs_modeltime), etc.

wfs_models = workflow_set(models = list(glmnet = linear_reg_glmnet_spec,
                                        svm = svm_poly_kernlab_spec,
                                        tree_xgboost = boost_tree_xgboost_spec),
                          preproc = list(rec_formula,
                                         rec_formula_norm),
                          cross = T)

wfs_modeltime = workflow_set(models = list(arima = arima_boost_arima_xgboost_spec,
                                           prophet = prophet_boost_prophet_xgboost_spec),
                             preproc = list(rec_date,
                                            rec_date_norm),
                             cross = T)

simonpcouch commented 2 years ago

Thanks for the issue!

Can you please supply a reprex and some explanation of your use case? What is the desired behavior of the combined workflow set?

exsell-jc commented 2 years ago

Thanks for the response. Here it is below:

library(tidytable)
library(tidymodels)
library(lubridate)
library(workflowsets)
library(rsample)
library(recipes)
library(modeltime)

### Sample data + resampling
df = data.frame(yearr = sample(2015:2021, 2000, replace = TRUE),
                monthh = sample(1:12, 2000, replace = TRUE),
                dayy = sample(1:29, 2000, replace = TRUE)) |>
  mutate.(datee = ymd(paste(yearr, monthh, dayy)),
         yy = sample(0:100, 2000, replace = TRUE) + (130 * yearr) + (2 * monthh)) |>
  filter.(!is.na(datee)) |>
  arrange.(datee) |>
  mutate.(ii = row_number.())

### Graph what it looks like
if (F) {
  df |>
    select.(datee, yy) |>
    ggplot() +
    geom_point(aes(y = yy,
                   x = datee),
               colour = 'red')
}

resamples_crossfolds = df |>
  vfold_cv(v = 3, repeats = 1) # `v` sets the number of folds; vfold_cv() has no `times` argument

### Recipe specifications
rec_formula = df |>
  recipe(yy ~ .) |>
  #update_role(datee, new_role = 'date') |> # THIS IS THE DIFFERENCE
  step_zv(all_predictors())

rec_formula_no_date = df |>
  recipe(yy ~ .) |>
  update_role(datee, new_role = 'date') |> # THIS IS THE DIFFERENCE
  step_zv(all_predictors())

#step_ # https://www.tmwr.org/pre-proc-table.html

### Modelling algorithms for rec_formula
arima_reg_arima_spec = arima_reg(seasonal_period = tune(),
                                 non_seasonal_ar = tune(),
                                 non_seasonal_differences = tune(),
                                 non_seasonal_ma = tune(),
                                 seasonal_ar = tune(),
                                 seasonal_differences = tune(),
                                 seasonal_ma = tune()) |>
  set_engine('arima')

exp_smoothing_ets_spec = exp_smoothing(seasonal_period = tune(),
                           error = tune(),
                           trend = tune(),
                           season = tune(),
                           damping = tune(),
                           smooth_level = tune(),
                           smooth_trend = tune(),
                           smooth_seasonal = tune()) |>
  set_engine('ets')

seasonal_reg_stlm_arima_spec = seasonal_reg(seasonal_period_1 = tune(),
                                            seasonal_period_2 = tune(),
                                            seasonal_period_3 = tune()) |>
  set_engine('stlm_arima')

seasonal_reg_stlm_ets_spec = seasonal_reg(seasonal_period_1 = tune(),
                                          seasonal_period_2 = tune(),
                                          seasonal_period_3 = tune()) |>
  set_engine('stlm_ets')

### Modelling algorithms for rec_formula_no_date
linear_reg_glmnet_spec = linear_reg(penalty = tune(),
                                    mixture = tune()) |>
  set_engine('glmnet')

svm_poly_kernlab_spec = svm_poly(cost = tune(),
                                 degree = tune(),
                                 scale_factor = tune(),
                                 margin = tune()) |>
  set_engine('kernlab') |>
  set_mode('regression')

#parsnip_addin() # Other available algorithms (GPU-usage is a different story)

### Set up for rec_formula
wfs_models = workflow_set(models = list(arima = arima_reg_arima_spec,
                                        ets = seasonal_reg_stlm_ets_spec),
                          preproc = list(rec_formula),
                          cross = T)

### Set up for rec_formula_no_date
wfs_models_no_date = workflow_set(models = list(glmnet = linear_reg_glmnet_spec,
                                                svm = svm_poly_kernlab_spec),
                          preproc = list(rec_formula_no_date),
                          cross = T)

### Wrong set up
wfs_models_wrong = workflow_set(models = list(glmnet = linear_reg_glmnet_spec,
                                              svm = svm_poly_kernlab_spec),
                                preproc = list(rec_formula),
                                cross = T)

### Run the models
tune_models = wfs_models |>
  workflow_map('tune_grid',
               resamples = resamples_crossfolds,
               grid = 3,
               metrics = metric_set(huber_loss),
               control = control_resamples(save_pred = T))

tune_models_no_date = wfs_models_no_date |>
  workflow_map('tune_grid',
               resamples = resamples_crossfolds,
               grid = 3,
               metrics = metric_set(huber_loss),
               control = control_resamples(save_pred = T))

tune_models_wrong = wfs_models_wrong |>
  workflow_map('tune_grid',
               resamples = resamples_crossfolds,
               grid = 3,
               metrics = metric_set(huber_loss),
               control = control_resamples(save_pred = T))

As you can see above, algorithms such as glmnet will not work because of the date column. This may be more a fault of how the algorithms were coded (e.g. instead of a date, y/m/d arguments could be included in the scope of the function, or the date column could just be removed automatically), but since it's not really feasible to change all of those algorithms individually, it's probably best to provide the ability to combine multiple workflowsets. Essentially, I would like to combine tune_models and tune_models_no_date. Crossing within each set (cross = T) is very convenient, but a bit more flexibility would really help with faster deployment of multiple workflows.

To continue with the previous code:

### Have to make separate plots
autoplot(tune_models)

autoplot(tune_models_no_date)

### Have to rank separately
best_model = rank_results(tune_models,
                          rank_metric = 'huber_loss',
                          select_best = T)

best_model_no_date = rank_results(tune_models_no_date,
                          rank_metric = 'huber_loss',
                          select_best = T)

### Fits separately... though I guess they can be combined using rbind()? Unsure if that would result in loss of certain features within the object
best_fits = collect_predictions(tune_models)

best_fits_no_date = collect_predictions(tune_models_no_date)

### etc.
df_yy_hat = best_model |>
  filter.(rank == 1) |>
  inner_join.(best_fits) |>
  select.(.pred) |>
  rename.(yy_hat = .pred)

df_yy_hat_no_date = best_model_no_date |> # separate name so the first result is not overwritten
  filter.(rank == 1) |>
  inner_join.(best_fits_no_date) |>
  select.(.pred) |>
  rename.(yy_hat = .pred)

simonpcouch commented 2 years ago

Thanks for putting this together! I think I understand what you mean more clearly now.

It's my sense that these workflows can live in the same workflowset just fine—you'll just need to make sure you pass the correct recipes to the correct model specifications. Defining your "combined" workflowset would look like:

wfs_models = workflow_set(models = list(arima = arima_reg_arima_spec,
                                        ets = seasonal_reg_stlm_ets_spec,
                                        glmnet = linear_reg_glmnet_spec,
                                        svm = svm_poly_kernlab_spec),
                          preproc = list(rec_formula,
                                         rec_formula,
                                         rec_formula_no_date,
                                         rec_formula_no_date),
                          cross = F)
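
If the two workflow sets already exist as separate objects, another option that may be worth trying is to stack them row-wise with dplyr::bind_rows(), which works as long as the wflow_id values stay unique. The following is only a minimal sketch on untuned workflow sets: the toy data and the names toy, rec_with_date, rec_without_date, wfs_a, wfs_b and wfs_total are hypothetical stand-ins for the objects in this thread, and plain linear_reg() is used in place of the thread's model specifications to keep dependencies small.

```r
library(dplyr)
library(parsnip)
library(recipes)
library(workflowsets)

# Hypothetical toy data with a date column, standing in for `df` above
toy = data.frame(datee = as.Date("2020-01-01") + 0:49,
                 xx = seq_len(50),
                 yy = seq_len(50) * 2)

# One recipe keeps the date as a predictor; the other gives it a
# non-predictor role, mirroring rec_formula and rec_formula_no_date
rec_with_date = recipe(yy ~ ., data = toy)
rec_without_date = recipe(yy ~ ., data = toy) |>
  update_role(datee, new_role = "date")

# Two separate workflow sets, each crossed on its own recipe
wfs_a = workflow_set(preproc = list(with_date = rec_with_date),
                     models = list(lm = linear_reg()),
                     cross = TRUE)
wfs_b = workflow_set(preproc = list(no_date = rec_without_date),
                     models = list(lm = linear_reg()),
                     cross = TRUE)

# Stack the two sets; each row keeps its own recipe/model pairing
wfs_total = bind_rows(wfs_a, wfs_b)
print(wfs_total$wflow_id)
```

The combined object could then be passed to workflow_map() once, so that ranking and plotting cover every workflow together.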

Thanks again for the issue!

github-actions[bot] commented 2 years ago

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.