spsanderson / tidyAML

Auto ML for the tidyverse
http://www.spsanderson.com/tidyAML/
Other
63 stars 7 forks source link

Make a standardized way of getting residuals #187

Closed spsanderson closed 8 months ago

spsanderson commented 8 months ago

Not every model exposes its residuals in a standard way. For example, for a linear model built with linear_reg() and the "lm" engine, the residuals can be pulled out of the fitted_wflw column with something like:

> l[[1]][["fit"]][["fit"]][["fit"]][["residuals"]] |> as_tibble() |> set_names(".resid")
# A tibble: 24 × 1
   .resid
    <dbl>
 1  1.03 
 2 -0.902
 3  3.24 
 4  0.785
 5 -1.40 
 6 -0.504
 7  0.166
 8 -2.35 
 9 -4.59 
10 -0.668
# ℹ 14 more rows
# ℹ Use `print(n = ...)` to see more rows

Here `l` is a list of fitted workflows. Some models, however, do not follow that structure; partykit is one example:

> l[[14]]
══ Workflow [trained] ═══════════════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: decision_tree()

── Preprocessor ─────────────────────────────────────────────────────────────────────────────────────
2 Recipe Steps

• step_dummy()
• step_normalize()

── Model ────────────────────────────────────────────────────────────────────────────────────────────

Model formula:
..y ~ cyl + disp + hp + drat + wt + qsec + vs + am + gear + carb

Fitted party:
[1] root
|   [2] wt <= -0.84777: 29.029 (n = 7, err = 89.8)
|   [3] wt > -0.84777: 17.106 (n = 17, err = 174.3)

Number of inner nodes:    1
Number of terminal nodes: 2

> l[[14]][["fit"]][["fit"]][["fit"]][["residuals"]] |> as_tibble() |> set_names(".resid")
Error in `[[.party`(l[[14]][["fit"]][["fit"]][["fit"]], "residuals") : 
  length(i) == 1 & is.numeric(i) is not TRUE

Perhaps a function should do this via an alternate route that still relies on broom, but uses broom::augment(new_data = .data):

# Unnest the per-model `res` list column, label every row with its
# "<engine> - <fns>" model, and compute the residual by hand.
filtered_tbl |>
  unnest(cols = res) |>
  mutate(
    pfe = paste0(.parsnip_engine, " - ", .parsnip_fns),
    .res = mpg - .pred
  ) |>
  select(pfe, mpg, .pred, .res)

# A tibble: 768 × 4
   pfe               mpg .pred   .res
   <chr>           <dbl> <dbl>  <dbl>
 1 lm - linear_reg  21    22.4 -1.36 
 2 lm - linear_reg  21    21.7 -0.692
 3 lm - linear_reg  22.8  27.4 -4.59 
 4 lm - linear_reg  21.4  21.9 -0.504
 5 lm - linear_reg  18.7  17.4  1.25 
 6 lm - linear_reg  18.1  21.4 -3.26 
 7 lm - linear_reg  14.3  15.0 -0.668
 8 lm - linear_reg  24.4  22.0  2.42 
 9 lm - linear_reg  22.8  24.5 -1.70 
10 lm - linear_reg  19.2  18.2  1.03 
# ℹ 758 more rows
# ℹ Use `print(n = ...)` to see more rows
fr_tbl <- fast_regression(
  .data = df,
  .rec_obj = recipe,
  .parsnip_fns = c(
    "linear_reg", "mars", "bag_mars", "rand_forest", "boost_tree", "bag_tree"
  ),
  .parsnip_eng = c("lm", "gee", "glm", "gls", "earth", "rpart", "lightgbm")
)

# Augment every fitted workflow against `df`, then compute the residual by
# hand from the outcome column (mpg) and the prediction (.pred).
fr_tbl |>
  mutate(
    res = map(fitted_wflw, \(wflw) broom::augment(wflw, new_data = df))
  ) |>
  unnest(cols = res) |>
  mutate(
    pfe = paste0(.parsnip_engine, " - ", .parsnip_fns),
    .res = mpg - .pred
  )
spsanderson commented 8 months ago

We need a way to extract the outcome variable in order to compute .resid.

spsanderson commented 8 months ago

This approach works, but the data needs to come from the final pred_wflw column:

# Long-format predictions of the first model
# (columns .data_category, .data_type, .value).
x <- purrr::pluck(
  internal_make_wflw_predictions(mod_fitted_tbl, splits_obj),
  1
)

# Drop .data_type, spread the categories (actual / predicted) into list
# columns, then unnest them side by side.
x |>
  dplyr::select(-.data_type) |>
  tidyr::pivot_wider(
    names_from = .data_category,
    values_from = .value,
    values_fn = list
  ) |>
  tidyr::unnest(cols = dplyr::everything())

# A tibble: 32 × 2
   actual predicted
    <dbl>     <dbl>
 1   21        26.0
 2   21        28.2
 3   22.8      14.7
 4   21.4      19.6
 5   18.7      14.3
 6   18.1      17.3
 7   14.3      21.3
 8   24.4      30.5
 9   22.8      17.4
10   19.2      19.3
# ℹ 22 more rows
# ℹ Use `print(n = ...)` to see more rows
spsanderson commented 8 months ago

This should work for regression; a separate function is needed for classification.

Function:

#' Extract regression residuals from a fast regression model table
#'
#' Computes residuals (actual minus predicted) for every model in a model
#' table built by the fast regression workflow. This function is for
#' regression only; classification models need a separate extractor.
#'
#' @param .model_tbl A model table inheriting from `fst_reg_spec_tbl` that
#'   also carries a `pred_wflw` list column, e.g. one produced by
#'   `mutate(pred_wflw = internal_make_wflw_predictions(...))`.
#'
#' @return An unnamed list with one tibble per model, ordered by
#'   `.model_id`. Each tibble has the columns `.model_type`
#'   (`"<engine> - <fns>"`), `.actual`, `.predicted` and `.resid`
#'   (`.actual - .predicted`).
#'
#' @section Errors:
#' Aborts (via `rlang::abort()`) when `.model_tbl` is not a fast regression
#' table, when it has no `pred_wflw` column, or when a model's predictions
#' are missing or lack either the actual or the predicted values.
extract_regression_residuals <- function(.model_tbl) {

  # Checks ----
  if (!inherits(.model_tbl, "fst_reg_spec_tbl")) {
    rlang::abort(message = "Input must be from fast regression.")
  }

  # A fast regression table is valid input only after predictions have been
  # added, so report the missing column instead of repeating the class error.
  if (!"pred_wflw" %in% names(.model_tbl)) {
    rlang::abort(
      message = paste0(
        "Input must contain a 'pred_wflw' column; add it with ",
        "internal_make_wflw_predictions()."
      )
    )
  }

  # Manipulation ----
  # As a factor, .model_id makes group_split() yield one group per model,
  # ordered by id.
  models_list <- .model_tbl |>
    dplyr::mutate(.model_id = forcats::as_factor(.model_id)) |>
    dplyr::group_split(.model_id)

  # Extract residuals ----
  purrr::map(models_list, function(obj) {
    # Label identifying the model, e.g. "lm - linear_reg".
    pe <- obj |> dplyr::pull(.parsnip_engine) |> purrr::pluck(1)
    pf <- obj |> dplyr::pull(.parsnip_fns) |> purrr::pluck(1)
    pfe <- paste0(pe, " - ", pf)

    pred_tbl <- obj |> dplyr::pull(pred_wflw) |> purrr::pluck(1)
    if (is.null(pred_tbl)) {
      rlang::abort(
        message = paste0("No predictions found for model '", pfe, "'.")
      )
    }

    # Spread the long .data_category / .value pairs into one list column per
    # category; .data_type is dropped so the categories collapse to one row.
    wide_tbl <- pred_tbl |>
      dplyr::select(-.data_type) |>
      tidyr::pivot_wider(
        names_from = .data_category,
        values_from = .value,
        values_fn = list
      )

    if (!all(c("actual", "predicted") %in% names(wide_tbl))) {
      rlang::abort(
        message = paste0(
          "Predictions for model '", pfe,
          "' must include both 'actual' and 'predicted' values."
        )
      )
    }

    # NOTE(review): actual and predicted values are paired purely by position
    # after unnesting; this assumes pred_wflw stores them in matching row
    # order -- confirm against internal_make_wflw_predictions().
    wide_tbl |>
      tidyr::unnest(cols = dplyr::everything()) |>
      dplyr::mutate(
        .resid = actual - predicted,
        .model_type = pfe
      ) |>
      dplyr::select(.model_type, actual, predicted, .resid) |>
      purrr::set_names(c(".model_type", ".actual", ".predicted", ".resid"))
  })
}

Example:

library(tidyAML)
library(tidyverse)
library(tidymodels)

tidymodels_prefer()
load_deps()

# Linear-regression model specs for three engines.
mod_spec_tbl <- fast_regression_parsnip_spec_tbl(
  .parsnip_eng = c("lm", "glm", "gee"),
  .parsnip_fns = "linear_reg"
)

rec_obj <- recipe(mpg ~ ., data = mtcars)
splits_obj <- create_splits(mtcars, "initial_split")

# Each step appends one list column: workflows, fitted workflows, predictions.
mod_tbl <- mod_spec_tbl |>
  mutate(wflw = full_internal_make_wflw(mod_spec_tbl, rec_obj))

mod_fitted_tbl <- mod_tbl |>
  mutate(fitted_wflw = internal_make_fitted_wflw(mod_tbl, splits_obj))

mod_pred_tbl <- mod_fitted_tbl |>
  mutate(
    pred_wflw = internal_make_wflw_predictions(mod_fitted_tbl, splits_obj)
  )

extract_regression_residuals(mod_pred_tbl)

> extract_regression_residuals(mod_pred_tbl)
[[1]]
# A tibble: 32 × 4
   .model_type     .actual .predicted  .resid
   <chr>             <dbl>      <dbl>   <dbl>
 1 lm - linear_reg    21         19.2   1.82 
 2 lm - linear_reg    21         19.2   1.77 
 3 lm - linear_reg    22.8       13.4   9.40 
 4 lm - linear_reg    21.4       28.1  -6.67 
 5 lm - linear_reg    18.7       18.5   0.179
 6 lm - linear_reg    18.1       13.1   5.05 
 7 lm - linear_reg    14.3       16.4  -2.14 
 8 lm - linear_reg    24.4       23.4   0.989
 9 lm - linear_reg    22.8       16.6   6.24 
10 lm - linear_reg    19.2       31.8 -12.6  
# ℹ 22 more rows
# ℹ Use `print(n = ...)` to see more rows

[[2]]
# A tibble: 32 × 4
   .model_type      .actual .predicted  .resid
   <chr>              <dbl>      <dbl>   <dbl>
 1 gee - linear_reg    21         19.3   1.66 
 2 gee - linear_reg    21         19.3   1.75 
 3 gee - linear_reg    22.8       13.4   9.44 
 4 gee - linear_reg    21.4       28.1  -6.73 
 5 gee - linear_reg    18.7       18.5   0.164
 6 gee - linear_reg    18.1       12.9   5.19 
 7 gee - linear_reg    14.3       16.5  -2.19 
 8 gee - linear_reg    24.4       23.5   0.894
 9 gee - linear_reg    22.8       16.6   6.25 
10 gee - linear_reg    19.2       31.9 -12.7  
# ℹ 22 more rows
# ℹ Use `print(n = ...)` to see more rows

[[3]]
# A tibble: 32 × 4
   .model_type      .actual .predicted  .resid
   <chr>              <dbl>      <dbl>   <dbl>
 1 glm - linear_reg    21         19.2   1.82 
 2 glm - linear_reg    21         19.2   1.77 
 3 glm - linear_reg    22.8       13.4   9.40 
 4 glm - linear_reg    21.4       28.1  -6.67 
 5 glm - linear_reg    18.7       18.5   0.179
 6 glm - linear_reg    18.1       13.1   5.05 
 7 glm - linear_reg    14.3       16.4  -2.14 
 8 glm - linear_reg    24.4       23.4   0.989
 9 glm - linear_reg    22.8       16.6   6.24 
10 glm - linear_reg    19.2       31.8 -12.6  
# ℹ 22 more rows
# ℹ Use `print(n = ...)` to see more rows