tidymodels / orbital

Turn Tidymodels Workflows Into Series of Equations
https://orbital.tidymodels.org
Other
16 stars 1 forks source link

Factor variables treated as character in impute_mode recipe step #56

Closed szimmer closed 1 month ago

szimmer commented 1 month ago

The problem

If you have a factor variable and use `step_impute_mode()`, the generated formula treats it as a character variable. In the example below, `cyl` is converted to a factor, but the imputation step produces the following equation:

#> • cyl = dplyr::if_else(is.na(cyl), "8", cyl)

Perhaps this should instead be generated as one of the following:

Reproducible example

# Reprex: fit a workflow whose recipe mode-imputes nominal predictors,
# translate it with orbital(), then predict from the orbital object.
library(orbital)
library(tidymodels)

# Convert `cyl` to a factor so it is selected by all_nominal_predictors().
mtcars_f <- mtcars |>
  dplyr::mutate(cyl=as.factor(cyl))

# Mode-impute nominal predictors (only `cyl` here), then center and scale
# the numeric predictors.
rec_spec <- recipe(mpg ~ ., data = mtcars_f) |>
  step_impute_mode(all_nominal_predictors()) |>
  step_normalize(all_numeric_predictors())

# Linear regression model specification.
lm_spec <- linear_reg()

# Bundle the recipe and the model into a single workflow.
wf_spec <- workflow(rec_spec, lm_spec)

# Fit on the factor-encoded training data.
wf_fit <- fit(wf_spec, mtcars_f)

# Convert the fitted workflow into a sequence of equations (printed below).
orbital_obj <- orbital(wf_fit)

orbital_obj
#> 
#> ── orbital Object ──────────────────────────────────────────────────────────────
#> • cyl = dplyr::if_else(is.na(cyl), "8", cyl)
#> • disp = (disp - 230.7219) / 123.9387
#> • hp = (hp - 146.6875) / 68.56287
#> • drat = (drat - 3.596562) / 0.5346787
#> • wt = (wt - 3.21725) / 0.9784574
#> • qsec = (qsec - 17.84875) / 1.786943
#> • vs = (vs - 0.4375) / 0.5040161
#> • am = (am - 0.40625) / 0.4989909
#> • gear = (gear - 3.6875) / 0.7378041
#> • carb = (carb - 2.8125) / 1.6152
#> • .pred = 19.73744 + (ifelse(cyl == "6", 1, 0) * -1.660307) + (ifelse(cyl ...
#> ────────────────────────────────────────────────────────────────────────────────
#> 11 equations in total.

# NOTE: this predicts on `mtcars`, where `cyl` is still numeric (<double>),
# not on the factor-encoded `mtcars_f` used for fitting. The imputation
# equation above fills missing values with the character "8", so
# dplyr::if_else() cannot combine <character> with <double> and errors.
# Predicting on `mtcars_f` avoids this type mismatch.
predict(orbital_obj, as_tibble(mtcars))
#> Error in `dplyr::mutate()`:
#> ℹ In argument: `cyl = dplyr::if_else(is.na(cyl), "8", cyl)`.
#> Caused by error in `dplyr::if_else()`:
#> ! Can't combine `true` <character> and `false` <double>.

Created on 2024-08-28 with reprex v2.1.0

Session info ``` r sessionInfo() #> R version 4.4.1 (2024-06-14 ucrt) #> Platform: x86_64-w64-mingw32/x64 #> Running under: Windows 10 x64 (build 19045) #> #> Matrix products: default #> #> #> locale: #> [1] LC_COLLATE=English_United States.utf8 #> [2] LC_CTYPE=English_United States.utf8 #> [3] LC_MONETARY=English_United States.utf8 #> [4] LC_NUMERIC=C #> [5] LC_TIME=English_United States.utf8 #> #> time zone: America/New_York #> tzcode source: internal #> #> attached base packages: #> [1] stats graphics grDevices utils datasets methods base #> #> other attached packages: #> [1] yardstick_1.3.1 workflowsets_1.1.0 workflows_1.1.4 tune_1.2.1 #> [5] tidyr_1.3.1 tibble_3.2.1 rsample_1.2.1 recipes_1.1.0 #> [9] purrr_1.0.2 parsnip_1.2.1 modeldata_1.4.0 infer_1.0.7 #> [13] ggplot2_3.5.1 dplyr_1.1.4 dials_1.3.0 scales_1.3.0 #> [17] broom_1.0.6 tidymodels_1.2.0 orbital_0.2.0 #> #> loaded via a namespace (and not attached): #> [1] tidyselect_1.2.1 timeDate_4032.109 R.utils_2.12.3 #> [4] fastmap_1.2.0 reprex_2.1.0 digest_0.6.36 #> [7] rpart_4.1.23 timechange_0.3.0 lifecycle_1.0.4 #> [10] survival_3.6-4 magrittr_2.0.3 compiler_4.4.1 #> [13] rlang_1.1.4 tools_4.4.1 utf8_1.2.4 #> [16] yaml_2.3.8 data.table_1.15.4 knitr_1.47 #> [19] DiceDesign_1.10 R.cache_0.16.0 withr_3.0.0 #> [22] R.oo_1.26.0 nnet_7.3-19 grid_4.4.1 #> [25] fansi_1.0.6 colorspace_2.1-0 future_1.34.0 #> [28] globals_0.16.3 iterators_1.0.14 MASS_7.3-60.2 #> [31] cli_3.6.3 rmarkdown_2.27 generics_0.1.3 #> [34] rstudioapi_0.16.0 future.apply_1.11.2 splines_4.4.1 #> [37] parallel_4.4.1 vctrs_0.6.5 hardhat_1.4.0 #> [40] Matrix_1.7-0 listenv_0.9.1 foreach_1.5.2 #> [43] gower_1.0.1 glue_1.7.0 parallelly_1.38.0 #> [46] codetools_0.2-20 lubridate_1.9.3 gtable_0.3.5 #> [49] munsell_0.5.1 GPfit_1.0-8 styler_1.10.3 #> [52] pillar_1.9.0 furrr_0.3.1 htmltools_0.5.8.1 #> [55] ipred_0.9-15 lava_1.8.0 R6_2.5.1 #> [58] lhs_1.2.0 tidypredict_0.5 evaluate_0.24.0 #> [61] lattice_0.22-6 R.methodsS3_1.8.2 backports_1.5.0 #> [64] 
class_7.3-22 Rcpp_1.0.12 prodlim_2024.06.25 #> [67] xfun_0.45 fs_1.6.4 pkgconfig_2.0.3 ```
szimmer commented 1 month ago

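Closed because I was typing too fast in the original example: I passed `mtcars` (where `cyl` is numeric) to `predict()` instead of the factor-encoded `mtcars_f`. With `mtcars_f`, prediction works as expected: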

# Corrected reprex: the same pipeline, but predict() receives the
# factor-encoded `mtcars_f` that was used for fitting.
library(orbital)
library(tidymodels)

# Convert `cyl` to a factor so it is selected by all_nominal_predictors().
mtcars_f <- mtcars |>
  dplyr::mutate(cyl=as.factor(cyl))

# Mode-impute nominal predictors (only `cyl` here), then center and scale
# the numeric predictors.
rec_spec <- recipe(mpg ~ ., data = mtcars_f) |>
  step_impute_mode(all_nominal_predictors()) |>
  step_normalize(all_numeric_predictors())

# Linear regression model specification.
lm_spec <- linear_reg()

# Bundle the recipe and the model into a single workflow.
wf_spec <- workflow(rec_spec, lm_spec)

# Fit on the factor-encoded training data.
wf_fit <- fit(wf_spec, mtcars_f)

# Convert the fitted workflow into a sequence of equations (printed below).
orbital_obj <- orbital(wf_fit)

orbital_obj
#> 
#> ── orbital Object ──────────────────────────────────────────────────────────────
#> • cyl = dplyr::if_else(is.na(cyl), "8", cyl)
#> • disp = (disp - 230.7219) / 123.9387
#> • hp = (hp - 146.6875) / 68.56287
#> • drat = (drat - 3.596562) / 0.5346787
#> • wt = (wt - 3.21725) / 0.9784574
#> • qsec = (qsec - 17.84875) / 1.786943
#> • vs = (vs - 0.4375) / 0.5040161
#> • am = (am - 0.40625) / 0.4989909
#> • gear = (gear - 3.6875) / 0.7378041
#> • carb = (carb - 2.8125) / 1.6152
#> • .pred = 19.73744 + (ifelse(cyl == "6", 1, 0) * -1.660307) + (ifelse(cyl ...
#> ────────────────────────────────────────────────────────────────────────────────
#> 11 equations in total.

# Predict on `mtcars_f`, where `cyl` is a factor as it was at fit time.
# The imputation equation evaluates without error and one prediction is
# returned per row (32 rows).
predict(orbital_obj, as_tibble(mtcars_f))
#> # A tibble: 32 × 1
#>    .pred
#>    <dbl>
#>  1  21.8
#>  2  21.2
#>  3  26.3
#>  4  19.6
#>  5  17.7
#>  6  19.0
#>  7  14.2
#>  8  23.7
#>  9  24.1
#> 10  18.5
#> # ℹ 22 more rows

Created on 2024-08-28 with reprex v2.1.0

Session info ``` r sessionInfo() #> R version 4.4.1 (2024-06-14 ucrt) #> Platform: x86_64-w64-mingw32/x64 #> Running under: Windows 10 x64 (build 19045) #> #> Matrix products: default #> #> #> locale: #> [1] LC_COLLATE=English_United States.utf8 #> [2] LC_CTYPE=English_United States.utf8 #> [3] LC_MONETARY=English_United States.utf8 #> [4] LC_NUMERIC=C #> [5] LC_TIME=English_United States.utf8 #> #> time zone: America/New_York #> tzcode source: internal #> #> attached base packages: #> [1] stats graphics grDevices utils datasets methods base #> #> other attached packages: #> [1] yardstick_1.3.1 workflowsets_1.1.0 workflows_1.1.4 tune_1.2.1 #> [5] tidyr_1.3.1 tibble_3.2.1 rsample_1.2.1 recipes_1.1.0 #> [9] purrr_1.0.2 parsnip_1.2.1 modeldata_1.4.0 infer_1.0.7 #> [13] ggplot2_3.5.1 dplyr_1.1.4 dials_1.3.0 scales_1.3.0 #> [17] broom_1.0.6 tidymodels_1.2.0 orbital_0.2.0 #> #> loaded via a namespace (and not attached): #> [1] tidyselect_1.2.1 timeDate_4032.109 R.utils_2.12.3 #> [4] fastmap_1.2.0 reprex_2.1.0 digest_0.6.36 #> [7] rpart_4.1.23 timechange_0.3.0 lifecycle_1.0.4 #> [10] survival_3.6-4 magrittr_2.0.3 compiler_4.4.1 #> [13] rlang_1.1.4 tools_4.4.1 utf8_1.2.4 #> [16] yaml_2.3.8 data.table_1.15.4 knitr_1.47 #> [19] DiceDesign_1.10 R.cache_0.16.0 withr_3.0.0 #> [22] R.oo_1.26.0 nnet_7.3-19 grid_4.4.1 #> [25] fansi_1.0.6 colorspace_2.1-0 future_1.34.0 #> [28] globals_0.16.3 iterators_1.0.14 MASS_7.3-60.2 #> [31] cli_3.6.3 rmarkdown_2.27 generics_0.1.3 #> [34] rstudioapi_0.16.0 future.apply_1.11.2 splines_4.4.1 #> [37] parallel_4.4.1 vctrs_0.6.5 hardhat_1.4.0 #> [40] Matrix_1.7-0 listenv_0.9.1 foreach_1.5.2 #> [43] gower_1.0.1 glue_1.7.0 parallelly_1.38.0 #> [46] codetools_0.2-20 lubridate_1.9.3 gtable_0.3.5 #> [49] munsell_0.5.1 GPfit_1.0-8 styler_1.10.3 #> [52] pillar_1.9.0 furrr_0.3.1 htmltools_0.5.8.1 #> [55] ipred_0.9-15 lava_1.8.0 R6_2.5.1 #> [58] lhs_1.2.0 tidypredict_0.5 evaluate_0.24.0 #> [61] lattice_0.22-6 R.methodsS3_1.8.2 backports_1.5.0 #> [64] 
class_7.3-22 Rcpp_1.0.12 prodlim_2024.06.25 #> [67] xfun_0.45 fs_1.6.4 pkgconfig_2.0.3 ```