spsanderson / tidyAML

Auto ML for the tidyverse
http://www.spsanderson.com/tidyAML/
Other
64 stars 7 forks source link

Add function `internal_make_wflw_gee_lin_reg()` #167

Closed spsanderson closed 9 months ago

spsanderson commented 9 months ago

Add function internal_make_wflw_gee_lin_reg()

#' Build GEE linear-regression workflows from a tidyAML model spec table
#'
#' Internal helper. For every model in `.model_tbl` it builds a
#' `workflows::workflow()` that declares the recipe's outcome and predictor
#' variables with `workflows::add_variables()`. It then attaches the model spec
#' with a formula in which the recipe's first predictor is wrapped in
#' `id_var()`. That wrapped predictor is presumably the cluster identifier
#' expected by the GEE engine; confirm against the multilevelmod documentation.
#'
#' @param .model_tbl A table inheriting from `tidyaml_mod_spec_tbl`. It needs a
#'   `.model_id` column and a `model_spec` column. The first spec in
#'   `model_spec` must carry the attribute
#'   `.tidyaml_mod_class = "gee_linear_reg"`.
#' @param .rec_obj A `recipes::recipe()`. It is prepped once to obtain the
#'   base model formula; its `var_info` supplies the outcome/predictor roles.
#'
#' @return A list with one element per `.model_id` group. Each element is the
#'   workflow, or `NULL` when `workflows::add_model()` fails; in that case the
#'   error message is emitted with `message()`.
#'
#' @noRd
internal_make_wflw_gee_lin_reg <- function(.model_tbl, .rec_obj) {

  # Tidyeval ----
  model_tbl <- .model_tbl
  rec_obj <- .rec_obj

  # Checks ----
  if (!inherits(model_tbl, "tidyaml_mod_spec_tbl")) {
    rlang::abort(
      message = "'.model_tbl' must inherit a class of 'tidyaml_mod_spec_tbl'",
      use_cli_format = TRUE
    )
  }

  # Read the class tag only after the class check has passed.
  # - The length guard covers a zero-row table.
  # - `attr(exact = TRUE)` avoids `$` partial matching.
  # - `isTRUE()` turns a missing tag (NULL -> logical(0)) into a clean abort
  #   instead of a zero-length `if ()` error.
  specs <- model_tbl[["model_spec"]]
  first_spec <- if (length(specs) > 0L) specs[[1L]] else NULL
  mod_class <- attr(first_spec, ".tidyaml_mod_class", exact = TRUE)
  if (!isTRUE(mod_class == "gee_linear_reg")) {
    rlang::abort(
      message = "The model class is not 'gee_linear_reg'.",
      use_cli_format = TRUE
    )
  }

  # Variables and formula ----
  # Every model shares the same recipe, so the roles and the GEE formula are
  # derived once rather than re-prepping the recipe for each model.
  predictor_vars <- rec_obj$var_info |>
    dplyr::filter(role == "predictor") |>
    dplyr::pull(variable)
  outcome_var <- rec_obj$var_info |>
    dplyr::filter(role == "outcome") |>
    dplyr::pull(variable)

  if (length(predictor_vars) == 0L) {
    rlang::abort(
      message = "The recipe must contain at least one predictor variable.",
      use_cli_format = TRUE
    )
  }

  # The first predictor is the one wrapped in `id_var()`.
  var_to_replace <- predictor_vars[[1L]]

  # Build `id_var(<predictor>)` as a call rather than by parsing a pasted
  # string, so non-syntactic column names are handled correctly.
  id_term <- call("id_var", as.name(var_to_replace))

  base_formula <- stats::formula(recipes::prep(rec_obj))
  new_formula <- do.call(
    "substitute",
    list(base_formula, stats::setNames(list(id_term), var_to_replace))
  )
  new_formula <- stats::as.formula(new_formula)

  # `safely()` turns an `add_model()` failure into list(result, error) instead
  # of aborting the whole batch.
  safe_add_model <- purrr::safely(
    workflows::add_model,
    otherwise = NULL,
    quiet = TRUE
  )

  # Workflows ----
  models_list <- model_tbl |>
    dplyr::mutate(.model_id = forcats::as_factor(.model_id)) |>
    dplyr::group_split(.model_id)

  purrr::map(
    models_list,
    function(obj) {
      # Select the spec by column name rather than by position.
      mod <- purrr::pluck(obj[["model_spec"]], 1)

      ret <- workflows::workflow() |>
        workflows::add_variables(
          outcomes = dplyr::all_of(outcome_var),
          predictors = dplyr::all_of(predictor_vars)
        ) |>
        safe_add_model(mod, formula = new_formula)

      if (!is.null(ret$error)) {
        message(conditionMessage(ret$error))
      }

      ret$result
    }
  )
}