szimmer closed this issue 1 month ago.
Closed because I was typing too fast in the original example.
``` r
library(orbital)
library(tidymodels)

# Make cyl a factor so the recipe treats it as a nominal predictor
mtcars_f <- mtcars |>
  dplyr::mutate(cyl = as.factor(cyl))

# Impute nominal predictors with their mode, then center and scale numeric ones
rec_spec <- recipe(mpg ~ ., data = mtcars_f) |>
  step_impute_mode(all_nominal_predictors()) |>
  step_normalize(all_numeric_predictors())

lm_spec <- linear_reg()
wf_spec <- workflow(rec_spec, lm_spec)
wf_fit <- fit(wf_spec, mtcars_f)

# Convert the fitted workflow into orbital equations and print them
orbital_obj <- orbital(wf_fit)
orbital_obj
#>
#> ── orbital Object ──────────────────────────────────────────────────────────────
#> • cyl = dplyr::if_else(is.na(cyl), "8", cyl)
#> • disp = (disp - 230.7219) / 123.9387
#> • hp = (hp - 146.6875) / 68.56287
#> • drat = (drat - 3.596562) / 0.5346787
#> • wt = (wt - 3.21725) / 0.9784574
#> • qsec = (qsec - 17.84875) / 1.786943
#> • vs = (vs - 0.4375) / 0.5040161
#> • am = (am - 0.40625) / 0.4989909
#> • gear = (gear - 3.6875) / 0.7378041
#> • carb = (carb - 2.8125) / 1.6152
#> • .pred = 19.73744 + (ifelse(cyl == "6", 1, 0) * -1.660307) + (ifelse(cyl ...
#> ────────────────────────────────────────────────────────────────────────────────
#> 11 equations in total.

# Predict with the orbital equations
predict(orbital_obj, as_tibble(mtcars_f))
#> # A tibble: 32 × 1
#>    .pred
#>    <dbl>
#>  1  21.8
#>  2  21.2
#>  3  26.3
#>  4  19.6
#>  5  17.7
#>  6  19.0
#>  7  14.2
#>  8  23.7
#>  9  24.1
#> 10  18.5
#> # ℹ 22 more rows
```
Created on 2024-08-28 with reprex v2.1.0
The problem
If you have a factor variable and use `step_impute_mode()`, the generated formula treats it as a character variable. In the reproducible example, `cyl` is converted to a factor, but the imputation step generates the following line, which fills missing values with the string `"8"`:

`cyl = dplyr::if_else(is.na(cyl), "8", cyl)`
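A minimal sketch of why the generated line still runs on a factor, assuming dplyr 1.1.0 or later (the session info lists dplyr 1.1.4): in these versions `if_else()` casts `true` and `false` to a common type, and the common type of a character value and a factor is character. The toy `cyl` vector below is hypothetical, not taken from `mtcars`.

```r
library(dplyr)

# Hypothetical toy column: a factor with one missing value
cyl <- factor(c("6", NA, "4", "8"))

# The imputation line orbital generated for step_impute_mode()
imputed <- dplyr::if_else(is.na(cyl), "8", cyl)

class(imputed)                # "character": the factor is cast to character
imputed                       # "6" "8" "4" "8"

# The dummy-variable terms in .pred still work, because they compare labels
ifelse(imputed == "6", 1, 0)  # 1 0 0 0
```

So the equation returns a character column rather than a factor, but the `ifelse(cyl == "6", ...)` terms in `.pred` evaluate the same way on either type.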
Perhaps this should instead be generated as one of the following:
``` r
cyl = dplyr::if_else(is.na(cyl), 8, cyl)
cyl = dplyr::if_else(is.na(cyl), "8", as.character(cyl))
```
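A sketch comparing the two proposals on the same kind of hypothetical factor column, again assuming dplyr 1.1.0 or later. vctrs (0.6.5 in the session info) defines no common type for a double and a factor, so the first proposal is expected to raise an error rather than impute; the second returns exactly what the currently generated line already returns, only with the conversion spelled out.

```r
library(dplyr)

# Hypothetical toy column: a factor with one missing value
cyl <- factor(c("6", NA, "4", "8"))

# Proposal 1: numeric fill value against a factor column
alt1 <- tryCatch(
  dplyr::if_else(is.na(cyl), 8, cyl),
  error = function(e) e  # keep the condition object instead of stopping
)

# Proposal 2: convert to character explicitly before filling
alt2 <- dplyr::if_else(is.na(cyl), "8", as.character(cyl))

# Line currently generated by orbital, for comparison
current <- dplyr::if_else(is.na(cyl), "8", cyl)

inherits(alt1, "error")   # TRUE: a double and a factor cannot be combined
identical(alt2, current)  # TRUE
```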
Reproducible example
See the reprex above.
Session info
``` r
sessionInfo()
#> R version 4.4.1 (2024-06-14 ucrt)
#> Platform: x86_64-w64-mingw32/x64
#> Running under: Windows 10 x64 (build 19045)
#>
#> Matrix products: default
#>
#>
#> locale:
#> [1] LC_COLLATE=English_United States.utf8
#> [2] LC_CTYPE=English_United States.utf8
#> [3] LC_MONETARY=English_United States.utf8
#> [4] LC_NUMERIC=C
#> [5] LC_TIME=English_United States.utf8
#>
#> time zone: America/New_York
#> tzcode source: internal
#>
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#>
#> other attached packages:
#> [1] yardstick_1.3.1 workflowsets_1.1.0 workflows_1.1.4 tune_1.2.1
#> [5] tidyr_1.3.1 tibble_3.2.1 rsample_1.2.1 recipes_1.1.0
#> [9] purrr_1.0.2 parsnip_1.2.1 modeldata_1.4.0 infer_1.0.7
#> [13] ggplot2_3.5.1 dplyr_1.1.4 dials_1.3.0 scales_1.3.0
#> [17] broom_1.0.6 tidymodels_1.2.0 orbital_0.2.0
#>
#> loaded via a namespace (and not attached):
#> [1] tidyselect_1.2.1 timeDate_4032.109 R.utils_2.12.3
#> [4] fastmap_1.2.0 reprex_2.1.0 digest_0.6.36
#> [7] rpart_4.1.23 timechange_0.3.0 lifecycle_1.0.4
#> [10] survival_3.6-4 magrittr_2.0.3 compiler_4.4.1
#> [13] rlang_1.1.4 tools_4.4.1 utf8_1.2.4
#> [16] yaml_2.3.8 data.table_1.15.4 knitr_1.47
#> [19] DiceDesign_1.10 R.cache_0.16.0 withr_3.0.0
#> [22] R.oo_1.26.0 nnet_7.3-19 grid_4.4.1
#> [25] fansi_1.0.6 colorspace_2.1-0 future_1.34.0
#> [28] globals_0.16.3 iterators_1.0.14 MASS_7.3-60.2
#> [31] cli_3.6.3 rmarkdown_2.27 generics_0.1.3
#> [34] rstudioapi_0.16.0 future.apply_1.11.2 splines_4.4.1
#> [37] parallel_4.4.1 vctrs_0.6.5 hardhat_1.4.0
#> [40] Matrix_1.7-0 listenv_0.9.1 foreach_1.5.2
#> [43] gower_1.0.1 glue_1.7.0 parallelly_1.38.0
#> [46] codetools_0.2-20 lubridate_1.9.3 gtable_0.3.5
#> [49] munsell_0.5.1 GPfit_1.0-8 styler_1.10.3
#> [52] pillar_1.9.0 furrr_0.3.1 htmltools_0.5.8.1
#> [55] ipred_0.9-15 lava_1.8.0 R6_2.5.1
#> [58] lhs_1.2.0 tidypredict_0.5 evaluate_0.24.0
#> [61] lattice_0.22-6 R.methodsS3_1.8.2 backports_1.5.0
#> [64] class_7.3-22 Rcpp_1.0.12 prodlim_2024.06.25
#> [67] xfun_0.45 fs_1.6.4 pkgconfig_2.0.3
```