Closed: spsanderson closed this issue 2 years ago.
Set up the example data:
library(tidyverse)
# Toy example data: three numeric columns (B, A, C), built directly as a
# tibble rather than converting a data.frame.
data_tbl <- tibble::tibble(
  B = c(0, 2, 4),
  A = c(1, 3, 5),
  C = c(2, 4, 6)
)
Function hai_polynomial_augment:
# Augment a data set with polynomial / interaction features.
#
# The new features are the columns of `stats::model.matrix()` built from
# either `.formula` or, when `.formula` is NULL, from a generated formula
# `<pred_col> ~ poly(x1, .degree) + poly(x2, .degree) + ...` over every
# column other than `.pred_col`. The model-matrix columns are cleaned with
# `janitor::clean_names()`, prefixed with `.new_col_prefix`, and appended to
# the data.
#
# Parameters:
#   .data           A data.frame or tibble.
#   .formula        A formula, or anything `stats::as.formula()` accepts. When
#                   given it is used as-is and `.degree` is ignored.
#   .pred_col       Response column (tidyselect: bare name or string). It is
#                   moved to the front of the output. Required when
#                   `.formula` is NULL.
#   .degree         Degree passed to `poly()` for the generated formula;
#                   coerced with `as.integer()` and must then be a single
#                   value >= 1. Only used when `.formula` is NULL.
#   .new_col_prefix A single character string prepended to every new column.
#
# Returns a tibble: the input columns (with `.pred_col` first, if given)
# followed by the prefixed model-matrix columns. The formula actually used is
# reported with `message()`. Errors (via `stop()`) on invalid parameters.
hai_polynomial_augment <- function(.data, .formula = NULL, .pred_col = NULL,
                                   .degree = 1, .new_col_prefix = "nt_") {
  # Tidyeval ----
  f <- .formula
  pred_col_var_expr <- rlang::enquo(.pred_col)
  ncp <- .new_col_prefix

  # Manipulate ----
  # Put the response column (if any) first so that every remaining column is
  # a predictor. With `.pred_col = NULL` the column order is unchanged.
  data_tbl <- .data %>%
    tibble::as_tibble() %>%
    dplyr::select(!!pred_col_var_expr, dplyr::everything())

  # Checks ----
  pred_col_given <- !rlang::quo_is_missing(pred_col_var_expr) &&
    !rlang::quo_is_null(pred_col_var_expr)
  # as.integer() keeps the original coercion ("2" -> 2L, 2.7 -> 2L); invalid
  # input becomes NA / integer(0) and is rejected below with a clear error.
  d <- suppressWarnings(as.integer(.degree))
  degree_ok <- length(d) == 1L && !is.na(d) && d >= 1L

  if (!is.null(f)) {
    f <- stats::as.formula(f)
  } else if (pred_col_given && degree_ok) {
    predictors <- colnames(data_tbl)[-1]
    if (length(predictors) == 0L) {
      stop(
        "`.data` has no columns other than `.pred_col` to expand.",
        call. = FALSE
      )
    }
    # Backticks let non-syntactic column names (e.g. "my col") parse; the
    # response is passed as a symbol so it is never looked up as a variable.
    f <- stats::reformulate(
      termlabels = paste0("poly(`", predictors, "`, ", d, ")"),
      response = as.symbol(colnames(data_tbl)[1])
    )
  } else {
    # Deparse / label the inputs instead of evaluating them: evaluating a bare
    # column name passed as `.pred_col` would itself raise an error here.
    stop(
      "There is an issue with how you entered your parameters. Please fix.",
      "\nYou have .formula = ", paste(deparse(.formula), collapse = " "),
      "\nYou have .pred_col = ", rlang::as_label(pred_col_var_expr),
      "\nYou have .degree = ", paste(deparse(.degree), collapse = " "),
      "\nIf you have .formula = NULL, then you must set .pred_col AND .degree",
      " (a whole number >= 1).",
      call. = FALSE
    )
  }

  if (!is.character(ncp) || length(ncp) != 1L || is.na(ncp)) {
    stop(".new_col_prefix must be a quoted character string", call. = FALSE)
  }

  # Augment ----
  # Build the model frame with na.pass so the model matrix has exactly one
  # row per input row; the default na.omit would drop rows and the cbind()
  # below could then misalign or silently recycle them.
  mf <- stats::model.frame(f, data = data_tbl, na.action = stats::na.pass)
  mm_df <- stats::model.matrix(f, data = mf) %>%
    base::as.data.frame() %>%
    janitor::clean_names()
  colnames(mm_df) <- paste0(ncp, colnames(mm_df))
  data_tbl <- cbind(data_tbl, mm_df) %>% tibble::as_tibble()

  # Return ----
  # deparse() may split long formulas over several strings; join them.
  message("The formula used is: ", paste(deparse(f), collapse = " "))
  data_tbl
}
Example usage:
> hai_polynomial_augment(.data = data_tbl, .pred_col = A, .degree = 2, .new_col_prefix = "n")
The formula used is: A ~ poly(B, 2) + poly(C, 2)
# A tibble: 3 x 8
A B C nintercept npoly_b_2_1 npoly_b_2_2 npoly_c_2_1 npoly_c_2_2
<dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 1 0 2 1 -7.07e- 1 0.408 -7.07e- 1 0.408
2 3 2 4 1 -7.85e-17 -0.816 -7.85e-17 -0.816
3 5 4 6 1 7.07e- 1 0.408 7.07e- 1 0.408
> hai_polynomial_augment(.data = data_tbl, .formula = A ~ .^2, .degree = 1)
The formula used is: A ~ .^2
# A tibble: 3 x 7
B A C nt_intercept nt_b nt_c nt_b_c
<dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 0 1 2 1 0 2 0
2 2 3 4 1 2 4 8
3 4 5 6 1 4 6 24
Reference (PyCaret polynomial features): https://pycaret.org/polynomial-features/