spsanderson / tidyAML

Auto ML for the tidyverse
http://www.spsanderson.com/tidyAML/
Other
63 stars 7 forks source link

Update `internal_make_spec_tbl()` function #157

Closed spsanderson closed 8 months ago

spsanderson commented 8 months ago

Update internal_make_spec_tbl() to add an extra class to each parsnip model specification. The class combines the parsnip engine and the parsnip function name, e.g. `lm_linear_reg`.

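The updated function is shown below. It is followed by an example comparing the output and classes of the current version (v1) and the updated version (v2):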

# Build one parsnip model specification per row of a tidyAML base tibble.
#
# .model_tbl: a tibble inheriting from "tidyaml_base_tbl" (as produced by the
#   make base tbl functions). After `.model_id` is moved to the front, columns
#   2, 3 and 4 are read *by position* as the parsnip engine, mode and model
#   function name, so the base tibble's column order matters.
#
# Returns the input rows with an integer `.model_id` as the first column and
# a list-column `model_spec`. Each spec gets one extra class,
# "<engine>_<function>" (lower-cased), appended after its existing classes,
# e.g. c("linear_reg", "model_spec", "lm_linear_reg").
#
# Errors (via rlang::abort()) if `.model_tbl` is not a "tidyaml_base_tbl".
internal_make_spec_tbl <- function(.model_tbl) {

  # Tidyeval ----
  model_tbl <- .model_tbl

  # Checks ----
  if (!inherits(model_tbl, "tidyaml_base_tbl")) {
    rlang::abort(
      message = "The model tibble must come from the make base tbl function.",
      use_cli_format = TRUE
    )
  }

  # Model function lookup ----
  # match.fun() resolves a name in the parent of its caller. Inside a purrr
  # mapper that is purrr's own frame, so parsnip functions were only found
  # when parsnip happened to be attached. Look the name up from this
  # function's top-level environment instead (its namespace when installed,
  # which still falls through to the global env and search path), then fall
  # back to parsnip's exported functions.
  lookup_env <- topenv(environment())
  resolve_model_fn <- function(fn_name) {
    fn <- get0(fn_name, envir = lookup_env, mode = "function")
    if (is.null(fn)) {
      fn <- getExportedValue("parsnip", fn_name)
    }
    fn
  }

  # Manipulation ----
  # A factor id built from row_number() has levels in row order, so
  # group_split() below yields one single-row tibble per row, in order.
  model_factor_tbl <- model_tbl |>
    dplyr::mutate(.model_id = forcats::as_factor(dplyr::row_number())) |>
    dplyr::select(.model_id, dplyr::everything())

  models_list <- model_factor_tbl |>
    dplyr::group_split(.model_id)

  # Build the parsnip specification for each single-row tibble
  model_spec <- models_list |>
    purrr::map(
      .f = function(obj) {
        pe <- obj |> dplyr::pull(2) |> purrr::pluck(1) # engine
        pm <- obj |> dplyr::pull(3) |> purrr::pluck(1) # mode
        pf <- obj |> dplyr::pull(4) |> purrr::pluck(1) # model function name

        ret <- resolve_model_fn(pf)(mode = pm, engine = pe)

        # Append "<engine>_<function>" after the existing parsnip classes
        class(ret) <- c(
          class(ret),
          paste0(base::tolower(pe), "_", base::tolower(pf))
        )

        ret
      }
    )

  # Return as a tibble with an integer model id
  model_factor_tbl |>
    dplyr::mutate(model_spec = model_spec) |>
    dplyr::mutate(.model_id = as.integer(.model_id))
}

Example:

library(tidyAML)
library(dplyr)
library(tidymodels)

> mod_tbl <- make_regression_base_tbl()
> mod_filtered_tbl <- mod_tbl |>
+   filter(.parsnip_engine %in% c("gee","lm") & .parsnip_fns == "linear_reg")
> 
> mod_spec_tblv1 <- mod_filtered_tbl |>
+   internal_make_spec_tbl()
> 
> mod_spec_tblv2 <- mod_filtered_tbl |>
+   internal_make_spec_tblv2()
> 
> mod_spec_tblv1
# A tibble: 2 × 5
  .model_id .parsnip_engine .parsnip_mode .parsnip_fns model_spec
      <int> <chr>           <chr>         <chr>        <list>    
1         1 lm              regression    linear_reg   <spec[+]> 
2         2 gee             regression    linear_reg   <spec[+]> 
> mod_spec_tblv2
# A tibble: 2 × 5
  .model_id .parsnip_engine .parsnip_mode .parsnip_fns model_spec
      <int> <chr>           <chr>         <chr>        <list>    
1         1 lm              regression    linear_reg   <spec[+]> 
2         2 gee             regression    linear_reg   <spec[+]> 
> mod_spec_tblv1 |>
+   dplyr::pull(model_spec) |>
+   purrr::map(class)
[[1]]
[1] "linear_reg" "model_spec"

[[2]]
[1] "linear_reg" "model_spec"

> mod_spec_tblv2 |>
+   dplyr::pull(model_spec) |>
+   purrr::map(class)
[[1]]
[1] "linear_reg"    "model_spec"    "lm_linear_reg"

[[2]]
[1] "linear_reg"     "model_spec"     "gee_linear_reg"
spsanderson commented 8 months ago

The only difference between the two outputs is the additional engine_function class (e.g. `lm_linear_reg`, `gee_linear_reg`).

image