spsanderson / tidyAML

Auto ML for the tidyverse
http://www.spsanderson.com/tidyAML/
Other
66 stars 7 forks source link

Make `fast_regression_to_parsnip_call_tbl()` function #5

Closed spsanderson closed 1 year ago

spsanderson commented 1 year ago

This function will generate a tibble with simple default regression calls in the form of:

parsnip::linear_reg(mode = "regression", engine = "some_supplied_engine")

The result is a tibble.

spsanderson commented 1 year ago

Function:

library(tidymodels)
library(tidyverse)

#' Build a tibble of default parsnip regression model specifications
#'
#' Creates one row per supported model-function / engine pair and stores the
#' matching parsnip model specification, built as
#' `fn(mode = <mode>, engine = <engine>)`, in a list column.
#'
#' @param .parsnip_fns Character vector of parsnip model function names to
#'   keep (e.g. `"linear_reg"`). The default, `"all"`, keeps every function.
#' @param .parsnip_eng Character vector of engine names to keep
#'   (e.g. `c("lm", "glm")`). The default, `"all"`, keeps every engine.
#'
#' @return A tibble with the extra class `fst_reg_spec_tbl` and the columns
#'   `.parsnip_engine`, `.parsnip_mode`, `.parsnip_fns` and `.model_spec`.
#'   The arguments as supplied are stored in the `.parsnip_engines` and
#'   `.parsnip_functions` attributes. If nothing matches the filters, a
#'   zero-row tibble is returned.
#'
#' @examples
#' fast_regression_to_parsnip_call_tbl()
#' fast_regression_to_parsnip_call_tbl(.parsnip_eng = c("lm", "glm"))
#' fast_regression_to_parsnip_call_tbl(
#'   .parsnip_fns = "poisson_reg",
#'   .parsnip_eng = "glm"
#' )
fast_regression_to_parsnip_call_tbl <- function(.parsnip_fns = "all",
                                                .parsnip_eng = "all") {

  # Thank you https://stackoverflow.com/questions/74691333/build-a-tibble-of-parsnip-model-calls-with-match-fun/74691529#74691529
  # Arguments ----
  # Flatten list or vector input to a plain character vector. Anything that
  # is not character is rejected rather than silently coerced, because a
  # coerced value would just filter every row away.
  as_chr_arg <- function(x, arg) {
    flat <- unlist(x, use.names = FALSE)
    if (!is.null(flat) && !is.character(flat)) {
      stop("`", arg, "` must be a character vector.", call. = FALSE)
    }
    as.character(flat)
  }
  pf <- as_chr_arg(.parsnip_fns, ".parsnip_fns")
  pe <- as_chr_arg(.parsnip_eng, ".parsnip_eng")

  # Lookup tables ----
  # Engines are grouped by the model function they belong to, so the
  # function and mode columns are derived from the group sizes instead of
  # from hand-maintained repeat counts.
  engines_by_fn <- list(
    linear_reg = c(
      "lm", "brulee", "gee", "glm", "glmer", "glmnet", "gls", "h2o",
      "keras", "lme", "lmer", "spark", "stan", "stan_glmer"
    ),
    cubist_rules = "Cubist",
    poisson_reg = c(
      "glm", "gee", "glmer", "glmnet", "h2o", "hurdle", "stan",
      "stan_glmer", "zeroinfl"
    ),
    survival_reg = c("survival", "flexsurv", "flexsurvspline")
  )
  mode_by_fn <- c(
    linear_reg = "regression",
    cubist_rules = "regression",
    poisson_reg = "regression",
    survival_reg = "censored regression"
  )
  fn_names <- names(engines_by_fn)
  n_engines <- unname(lengths(engines_by_fn))

  # Make tibble ----
  mod_tbl <- tibble::tibble(
    .parsnip_engine = unlist(engines_by_fn, use.names = FALSE),
    .parsnip_mode = rep(unname(mode_by_fn[fn_names]), times = n_engines),
    .parsnip_fns = rep(fn_names, times = n_engines)
  )

  # Filter ----
  # "all" anywhere in an argument disables that filter. Columns are read
  # with [[ ]] so the same-named function argument cannot shadow them.
  keep <- rep(TRUE, nrow(mod_tbl))
  if (!"all" %in% pe) {
    keep <- keep & mod_tbl[[".parsnip_engine"]] %in% pe
  }
  if (!"all" %in% pf) {
    keep <- keep & mod_tbl[[".parsnip_fns"]] %in% pf
  }
  mod_tbl <- mod_tbl[keep, ]

  # Model specs ----
  # Inputs are passed by name, so the result does not depend on column
  # order. Constructors are taken from the parsnip namespace, so this works
  # whether or not parsnip is attached and cannot pick up a same-named
  # function from the user's workspace.
  model_specs <- purrr::pmap(
    list(
      fn = mod_tbl[[".parsnip_fns"]],
      mode = mod_tbl[[".parsnip_mode"]],
      engine = mod_tbl[[".parsnip_engine"]]
    ),
    function(fn, mode, engine) {
      getExportedValue("parsnip", fn)(mode = mode, engine = engine)
    }
  )
  mod_spec_tbl <- dplyr::mutate(mod_tbl, .model_spec = model_specs)

  # Return ----
  class(mod_spec_tbl) <- c("fst_reg_spec_tbl", class(mod_spec_tbl))
  attr(mod_spec_tbl, ".parsnip_engines") <- .parsnip_eng
  attr(mod_spec_tbl, ".parsnip_functions") <- .parsnip_fns

  mod_spec_tbl

}

Examples:

> frt <- fast_regression_to_parsnip_call_tbl()
> frt
# A tibble: 27 × 4
   .parsnip_engine .parsnip_mode .parsnip_fns .model_spec
   <chr>           <chr>         <chr>        <list>     
 1 lm              regression    linear_reg   <spec[+]>  
 2 brulee          regression    linear_reg   <spec[+]>  
 3 gee             regression    linear_reg   <spec[+]>  
 4 glm             regression    linear_reg   <spec[+]>  
 5 glmer           regression    linear_reg   <spec[+]>  
 6 glmnet          regression    linear_reg   <spec[+]>  
 7 gls             regression    linear_reg   <spec[+]>  
 8 h2o             regression    linear_reg   <spec[+]>  
 9 keras           regression    linear_reg   <spec[+]>  
10 lme             regression    linear_reg   <spec[+]>  
# … with 17 more rows
# ℹ Use `print(n = ...)` to see more rows
> attributes(frt)
$names
[1] ".parsnip_engine" ".parsnip_mode"   ".parsnip_fns"    ".model_spec"    

$row.names
 [1]  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
[27] 27

$class
[1] "fst_reg_spec_tbl" "tbl_df"           "tbl"              "data.frame"      

$.parsnip_engines
[1] "all"

$.parsnip_functions
[1] "all"

> class(frt)
[1] "fst_reg_spec_tbl" "tbl_df"           "tbl"              "data.frame"      
> fast_regression_to_parsnip_call_tbl(.parsnip_fns = "cubist_rules")
# A tibble: 1 × 4
  .parsnip_engine .parsnip_mode .parsnip_fns .model_spec
  <chr>           <chr>         <chr>        <list>     
1 Cubist          regression    cubist_rules <spec[+]>  
> fast_regression_to_parsnip_call_tbl(.parsnip_eng = "lm")
# A tibble: 1 × 4
  .parsnip_engine .parsnip_mode .parsnip_fns .model_spec
  <chr>           <chr>         <chr>        <list>     
1 lm              regression    linear_reg   <spec[+]>  
> fast_regression_to_parsnip_call_tbl(.parsnip_eng = "glm")
# A tibble: 2 × 4
  .parsnip_engine .parsnip_mode .parsnip_fns .model_spec
  <chr>           <chr>         <chr>        <list>     
1 glm             regression    linear_reg   <spec[+]>  
2 glm             regression    poisson_reg  <spec[+]>  
> fast_regression_to_parsnip_call_tbl(.parsnip_fns = "poisson_reg", 
+                                     .parsnip_eng = "glm")
# A tibble: 1 × 4
  .parsnip_engine .parsnip_mode .parsnip_fns .model_spec
  <chr>           <chr>         <chr>        <list>     
1 glm             regression    poisson_reg  <spec[+]>  
> fast_regression_to_parsnip_call_tbl(.parsnip_eng = c("lm","glm"))
# A tibble: 3 × 4
  .parsnip_engine .parsnip_mode .parsnip_fns .model_spec
  <chr>           <chr>         <chr>        <list>     
1 lm              regression    linear_reg   <spec[+]>  
2 glm             regression    linear_reg   <spec[+]>  
3 glm             regression    poisson_reg  <spec[+]>