spsanderson / tidyAML

Auto ML for the tidyverse
http://www.spsanderson.com/tidyAML/
Other
63 stars 7 forks source link

Update the function `plot_regression_predictions()` to drop the `%>%` #231

Closed spsanderson closed 6 months ago

spsanderson commented 6 months ago

New Function:

#' Plot observed versus predicted values for regression models
#'
#' Draws the actual values as a solid black line and overlays the training
#' (red, dashed) and testing (blue, dashed) values, either as one plot per
#' model type or as a single plot faceted by model type. The x axis is the
#' row position of each observation within its `.data_category` group.
#'
#' @param .data A data.frame/tibble with the columns `.model_type`,
#'   `.data_category`, `.data_type` and a numeric `.value`. Rows whose
#'   `.data_type` is "actual", "training" or "testing" supply the three lines.
#' @param .output Either "list" (the default) or "facet"; compared after
#'   lower-casing.
#'
#' @return For "list", a list of ggplot objects, one per `.model_type`.
#'   For "facet", a single ggplot object faceted by `.model_type`.
#'
plot_regression_predictions <- function(.data, .output = "list") {

  # Variables
  output <- tolower(.output)
  plot_title <- "Observed vs. Predicted Values by Model Type"
  y_label <- "Observed/Predicted Value"
  plot_caption <- "Black = Actual, Red = Training, Blue = Testing"

  # Checks
  # The length test comes first so a zero-length or multi-element `.output`
  # raises the validation error instead of failing inside `if ()`.
  if (length(output) != 1L || !output %in% c("list", "facet")) {
    rlang::abort(message = "output must be either 'list' or 'facet'.")
  }

  if (!is.data.frame(.data)) {
    rlang::abort(message = "data must be a data.frame/tibble.")
  }

  # `[[` yields NULL for an absent column without the warning `$` gives on
  # tibbles, so a missing `.value` is reported by this same check.
  if (!is.numeric(.data[[".value"]])) {
    rlang::abort(message = ".value must be numeric.")
  }

  # Both output modes group, filter or pivot on these columns; fail early
  # with a clear message rather than deep inside dplyr/tidyr.
  required_cols <- c(".model_type", ".data_category", ".data_type")
  missing_cols <- setdiff(required_cols, names(.data))
  if (length(missing_cols) > 0L) {
    rlang::abort(
      message = paste0(
        "data is missing required column(s): ",
        paste(missing_cols, collapse = ", "),
        "."
      )
    )
  }

  # Plot
  if (output == "list") {
    # Builds the plot for the rows of a single model type.
    plot_one_model <- function(model_df) {
      model_df |>
        dplyr::group_by(.data_category) |>
        dplyr::mutate(x = dplyr::row_number()) |>
        dplyr::ungroup() |>
        # One column per `.data_type` value (actual / training / testing).
        tidyr::pivot_wider(
          names_from = .data_type,
          values_from = .value
        ) |>
        ggplot2::ggplot(
          ggplot2::aes(x = x, y = actual, group = .data_category)
        ) +
        ggplot2::geom_line(color = "black") +
        ggplot2::geom_line(
          ggplot2::aes(x = x, y = training),
          linetype = "dashed", color = "red", linewidth = 1
        ) +
        ggplot2::geom_line(
          ggplot2::aes(x = x, y = testing),
          linetype = "dashed", color = "blue", linewidth = 1
        ) +
        ggplot2::theme_minimal() +
        ggplot2::labs(
          x = "",
          y = y_label,
          title = plot_title,
          subtitle = model_df$.model_type[1],
          caption = plot_caption
        )
    }

    p <- .data |>
      dplyr::group_split(.model_type) |>
      purrr::map(plot_one_model)
  } else {
    # Index observations once; every layer below reuses this `x`.
    df <- .data |>
      dplyr::group_by(.model_type, .data_category) |>
      dplyr::mutate(x = dplyr::row_number()) |>
      dplyr::ungroup()

    act_data <- dplyr::filter(df, .data_type == "actual")
    train_data <- dplyr::filter(df, .data_type == "training")
    test_data <- dplyr::filter(df, .data_type == "testing")

    p <- ggplot2::ggplot(df, ggplot2::aes(x = x, y = .value)) +
      ggplot2::geom_line(data = act_data, color = "black") +
      ggplot2::geom_line(
        data = train_data, linetype = "dashed", color = "red"
      ) +
      ggplot2::geom_line(
        data = test_data, linetype = "dashed", color = "blue"
      ) +
      ggplot2::facet_wrap(~ .model_type, ncol = 2, scales = "free") +
      ggplot2::labs(
        x = "",
        y = y_label,
        title = plot_title,
        caption = plot_caption
      ) +
      ggplot2::theme_minimal()
  }

  # Return
  p
}