spsanderson / tidyAML

Auto ML for the tidyverse
http://www.spsanderson.com/tidyAML/
Other
63 stars 7 forks source link

`gee` model #143

Closed spsanderson closed 9 months ago

spsanderson commented 1 year ago

The `gee` engine requires an explicit formula that wraps the grouping variable in `id_var()`, as in the following:

# The gee engine needs the id (grouping) variable marked with id_var();
# here `wool` plays that role and `tension` is an ordinary predictor.
linear_reg() |>
  set_engine("gee") |>
  fit(breaks ~ tension + id_var(wool), data = warpbreaks)

It cannot use the `.` shorthand, as in `fit(breaks ~ ., data = warpbreaks)`, because the `id_var()` term has to be written out explicitly.

This will necessitate a few changes. Parameters must be added to `fast_regression()` and `fast_classification()`, and sub-modules must be built to handle the extra call needed inside `internal_make_workflow()` and possibly `internal_make_fitted_wflw()`.

Something like the following should help:

# Build `outcome ~ pred1 + pred2 + ...` from the roles recorded in the recipe.
# reformulate() assembles the formula directly instead of pasting a string and
# re-parsing it; names are backquoted so non-syntactic column names (e.g. ones
# containing spaces) still parse.
outcome_col <- rec_obj$var_info |>
  filter(role == "outcome") |>
  pull(variable)
pred_cols <- rec_obj$var_info |>
  filter(role == "predictor") |>
  pull(variable)
# Fail with a clear message instead of a cryptic parse error.
stopifnot(length(outcome_col) == 1L, length(pred_cols) >= 1L)

reformulate(paste0("`", pred_cols, "`"), response = as.name(outcome_col))

Turn the above into a function, called something like `make_full_function(.rec_obj = rec_obj)`.

spsanderson commented 9 months ago

This seems to work:

library(tidyAML)
library(tidyverse)
library(tidymodels)
library(multilevelmod)

# Recipe predicting mpg from every other mtcars column, plus train/test splits.
rec_obj <- recipe(mpg ~ ., data = mtcars)
splits_obj <- create_splits(mtcars)

# Outcome and predictor names, taken from the roles recorded in the recipe.
outcome_var <- rec_obj$var_info |>
  filter(role == "outcome") |>
  pull(variable)
predictor_vars <- rec_obj$var_info |>
  filter(role == "predictor") |>
  pull(variable)
stopifnot(length(outcome_var) == 1L, length(predictor_vars) >= 1L)

# The gee engine rejects `y ~ .`: it needs an explicit formula in which the
# grouping variable is wrapped in id_var(). The first predictor (`cyl` for
# mtcars) is used as that variable and the remaining predictors are kept as
# plain terms, giving `mpg ~ id_var(cyl) + disp + hp + ...`.
var_to_replace <- predictor_vars[[1]]
new_formula <- reformulate(
  termlabels = c(
    paste0("id_var(", var_to_replace, ")"),
    setdiff(predictor_vars, var_to_replace)
  ),
  response = outcome_var
)

mod_tbl <- fast_regression_parsnip_spec_tbl(
  .parsnip_eng = c("gee"),
  .parsnip_fns = "linear_reg"
)

# One workflow per model spec row. all_of() hands the external name vectors
# to add_variables() in the form tidyselect expects.
mod_wflw_tbl <- mod_tbl |>
  mutate(
    wflw = map(model_spec, \(spec) {
      workflow() |>
        add_variables(
          outcomes = all_of(outcome_var),
          predictors = all_of(predictor_vars)
        ) |>
        add_model(spec, formula = new_formula)
    })
  )

# Fit and predict row-wise, addressing list-columns by name rather than by
# position so the code does not depend on the spec tibble's column layout.
mod_fitted_tbl <- mod_wflw_tbl |>
  mutate(
    fitted_wflw = map(wflw, \(wf) fit(wf, data = training(splits_obj$splits)))
  )

mod_pred_tbl <- mod_fitted_tbl |>
  mutate(
    pred_wflw = map(
      fitted_wflw,
      \(fitted) predict(fitted, new_data = testing(splits_obj$splits))
    )
  )

mod_pred_tbl$pred_wflw[[1]]

# A tibble: 8 × 1
  .pred
  <dbl>
1 21.6 
2 18.5 
3 16.2 
4  8.27
5 26.7 
6 29.2 
7 24.3 
8 18.0