spsanderson / tidyAML

Auto ML for the tidyverse
http://www.spsanderson.com/tidyAML/
Other
64 stars 7 forks source link

`fast_regression()` is producing incorrect workflow objects #27

Closed spsanderson closed 1 year ago

spsanderson commented 1 year ago

The following objects are being produced incorrectly

This stems from using `wflw[[1]]` in the code instead of mapping a function over the list-column object.

To fix the `wflw` column, make a helper function like so:

Function:

#' Safely make workflow objects
#'
#' Builds one workflow per `.model_id` in a model spec tibble by combining the
#' supplied recipe with each row's model specification. Adding the model is
#' wrapped in `purrr::safely()`, so a spec that cannot be added does not stop
#' the remaining rows from being processed.
#'
#' @param .model_tbl A tibble inheriting from class `tidyaml_mod_spec_tbl`.
#'   The model specification is read from its fifth column (`model_spec` in
#'   the output of `fast_regression_parsnip_spec_tbl()`).
#' @param .rec_obj The recipe passed to `workflows::add_recipe()` for every
#'   workflow.
#'
#' @return A list with one element per `.model_id`: the workflow object, or
#'   the string `"Error - Could not make workflow object."` when the model
#'   could not be added.
internal_make_wflw <- function(.model_tbl, .rec_obj) {

  # Tidyeval ----
  model_tbl <- .model_tbl
  rec_obj <- .rec_obj

  # Checks ----
  if (!inherits(model_tbl, "tidyaml_mod_spec_tbl")) {
    rlang::abort(
      message = "'.model_tbl' must inherit a class of 'tidyaml_mod_spec_tbl'",
      use_cli_format = TRUE
    )
  }

  # Manipulation ----
  # One single-row tibble per model id, in factor-level order
  models_list <- model_tbl %>%
    dplyr::mutate(.model_id = forcats::as_factor(.model_id)) %>%
    dplyr::group_split(.model_id)

  # Create a safe add_model function once; it is the same for every row
  safe_add_model <- purrr::safely(
    workflows::add_model,
    otherwise = "Error - Could not make workflow object.",
    quiet = FALSE
  )

  # Make the workflow objects ----
  # The recipe is taken straight from the enclosing scope rather than being
  # stored in (and pulled back out of) a list column: that avoids depending on
  # the column's position and on data masking when the input already carries
  # a column called `rec_obj`.
  purrr::map(
    models_list,
    function(obj) {

      # Pull the model column (position 5) and then pluck the model
      mod <- obj %>% dplyr::pull(5) %>% purrr::pluck(1)

      # Build the workflow with recipe and model; safely() returns a
      # list(result, error)
      ret <- workflows::workflow() %>%
        workflows::add_recipe(rec_obj) %>%
        safe_add_model(mod)

      # Keep only the result (the workflow, or the `otherwise` string)
      purrr::pluck(ret, "result")
    }
  )
}

Example:

library(tidyverse)
library(tidymodels)
tidymodels_prefer()

# Build the model spec table: one linear_reg spec per engine
mod_spec_tbl <- fast_regression_parsnip_spec_tbl(
  .parsnip_eng = c("lm", "glm"),
  .parsnip_fns = "linear_reg"
)

# A tibble: 2 × 5
  .model_id .parsnip_engine .parsnip_mode .parsnip_fns model_spec
      <int> <chr>           <chr>         <chr>        <list>    
1         1 lm              regression    linear_reg   <spec[+]> 
2         2 glm             regression    linear_reg   <spec[+]> 

# Attach a list column of workflows, one per model spec row.
# `rec_obj` must already exist in the session; this snippet does not create it.
mod_tbl <- dplyr::mutate(
  mod_spec_tbl,
  wflw = internal_make_wflw(mod_spec_tbl, .rec_obj = rec_obj)
)

> mod_tbl
# A tibble: 2 × 6
  .model_id .parsnip_engine .parsnip_mode .parsnip_fns model_spec wflw      
      <int> <chr>           <chr>         <chr>        <list>     <list>    
1         1 lm              regression    linear_reg   <spec[+]>  <workflow>
2         2 glm             regression    linear_reg   <spec[+]>  <workflow>
spsanderson commented 1 year ago

This will make the workflow; something similar should work for fitting:

# Recipe shared by every candidate model
rec_obj <- recipes::recipe(mpg ~ ., data = mtcars)

# Two linear_reg specs; the "gee" engine is the one that fails below
tst <- create_model_spec(
  .parsnip_fns = list("linear_reg"),
  .mode = list("regression"),
  .parsnip_eng = list("gee", "lm")
)

# Add a factor id as the first column and the recipe as the last one, so the
# model spec sits in column 5 and the recipe in column 6
tst <- tst %>%
  dplyr::mutate(
    .model_id = forcats::as_factor(dplyr::row_number()),
    rec_obj = list(rec_obj)
  ) %>%
  dplyr::select(.model_id, dplyr::everything())

# One single-row tibble per model id
models_tbl <- dplyr::group_split(tst, .model_id)

# Build one workflow per model; a failing spec yields "Error" instead of
# aborting the whole run
wflw_list <- purrr::imap(
  models_tbl,
  .f = function(model_row, idx) {

    # Model spec lives in column 5, recipe in column 6
    spec <- purrr::pluck(dplyr::pull(model_row, 5), 1)
    recipe_i <- purrr::pluck(dplyr::pull(model_row, 6), 1)

    # add_model wrapped so an unsupported engine is reported, not fatal
    safe_add_mod <- purrr::safely(
      workflows::add_model,
      otherwise = "Error",
      quiet = FALSE
    )

    # Start from an empty workflow, add the recipe, then try the model
    base_wflw <- workflows::add_recipe(workflows::workflow(), recipe_i)
    attempt <- safe_add_mod(base_wflw, spec)

    # Hand back only the result element of the safely() output
    purrr::pluck(attempt, "result")
  }
)

Error: parsnip could not locate an implementation for `linear_reg` regression model specifications using the `gee` engine.

> wflw_list
[[1]]
[1] "Error"

[[2]]
══ Workflow ═══════════════════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()

── Preprocessor ───────────────────────────────────────────────────────────────────────────────
0 Recipe Steps

── Model ──────────────────────────────────────────────────────────────────────────────────────
Linear Regression Model Specification (regression)

Computational engine: lm