Factor variables treated as character in impute_mode recipe step #56

szimmer commented 1 month ago

The problem

If a predictor is a factor and you use step_impute_mode(), the generated formula treats it as a character variable. In the example below, cyl is converted to a factor, but the imputation step produces the following line:

#> • cyl = dplyr::if_else(is.na(cyl), "8", cyl)

Perhaps this should instead be generated as one of the following:

Reproducible example


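library(tidymodels)
library(orbital)

# Make cyl a factor so it is picked up by all_nominal_predictors()
mtcars_f <- mtcars |>
  mutate(cyl = factor(cyl))

rec_spec <- recipe(mpg ~ ., data = mtcars_f) |>
  step_impute_mode(all_nominal_predictors()) |>
  step_normalize(all_numeric_predictors())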

lm_spec <- linear_reg()

wf_spec <- workflow(rec_spec, lm_spec)

wf_fit <- fit(wf_spec, mtcars_f)

orbital_obj <- orbital(wf_fit)
orbital_obj

#> ── orbital Object ──────────────────────────────────────────────────────────────
#> • cyl = dplyr::if_else(is.na(cyl), "8", cyl)
#> • disp = (disp - 230.7219) / 123.9387
#> • hp = (hp - 146.6875) / 68.56287
#> • drat = (drat - 3.596562) / 0.5346787
#> • wt = (wt - 3.21725) / 0.9784574
#> • qsec = (qsec - 17.84875) / 1.786943
#> • vs = (vs - 0.4375) / 0.5040161
#> • am = (am - 0.40625) / 0.4989909
#> • gear = (gear - 3.6875) / 0.7378041
#> • carb = (carb - 2.8125) / 1.6152
#> • .pred = 19.73744 + (ifelse(cyl == "6", 1, 0) * -1.660307) + (ifelse(cyl ...
#> ────────────────────────────────────────────────────────────────────────────────
#> 11 equations in total.

predict(orbital_obj, as_tibble(mtcars))
#> Error in `dplyr::mutate()`:
#> ℹ In argument: `cyl = dplyr::if_else(is.na(cyl), "8", cyl)`.
#> Caused by error in `dplyr::if_else()`:
#> ! Can't combine `true` <character> and `false` <double>.

Created on 2024-08-28 with reprex v2.1.0
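The error comes from dplyr::if_else() being type-stable: the `true` and `false` branches must share a common type. A minimal sketch, assuming dplyr >= 1.1.0, that evaluates the generated expression directly on a numeric cyl (as in as_tibble(mtcars)) and on a factor cyl (as in mtcars_f). The helper `impute_cyl()` and the vectors `cyl_num` and `cyl_fct` are hypothetical, for illustration only.

```r
library(dplyr)

# Hypothetical helper wrapping the expression generated for cyl
impute_cyl <- function(cyl) dplyr::if_else(is.na(cyl), "8", cyl)

# Numeric cyl, as in as_tibble(mtcars): "8" (character) and cyl (double)
# have no common type, so if_else() errors even though nothing is missing
cyl_num <- c(6, 4, 8)
res_num <- tryCatch(impute_cyl(cyl_num), error = function(e) e)
inherits(res_num, "error")

# Factor cyl, as in mtcars_f: character and factor combine to character,
# so the call succeeds and the missing value is filled with "8"
cyl_fct <- factor(c(6, NA, 4), levels = c(4, 6, 8))
res_fct <- impute_cyl(cyl_fct)
res_fct
```

Because the type check happens before any values are inspected, the numeric case fails even when there are no missing values, which matches the error raised on mtcars above.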

szimmer commented 1 month ago

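Closing: I was typing too fast in the original example. The error came from passing as_tibble(mtcars), where cyl is still numeric, to predict() instead of mtcars_f. With mtcars_f, prediction works: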


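library(tidymodels)
library(orbital)

# Make cyl a factor so it is picked up by all_nominal_predictors()
mtcars_f <- mtcars |>
  mutate(cyl = factor(cyl))

rec_spec <- recipe(mpg ~ ., data = mtcars_f) |>
  step_impute_mode(all_nominal_predictors()) |>
  step_normalize(all_numeric_predictors())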

lm_spec <- linear_reg()

wf_spec <- workflow(rec_spec, lm_spec)

wf_fit <- fit(wf_spec, mtcars_f)

orbital_obj <- orbital(wf_fit)
orbital_obj

#> ── orbital Object ──────────────────────────────────────────────────────────────
#> • cyl = dplyr::if_else(is.na(cyl), "8", cyl)
#> • disp = (disp - 230.7219) / 123.9387
#> • hp = (hp - 146.6875) / 68.56287
#> • drat = (drat - 3.596562) / 0.5346787
#> • wt = (wt - 3.21725) / 0.9784574
#> • qsec = (qsec - 17.84875) / 1.786943
#> • vs = (vs - 0.4375) / 0.5040161
#> • am = (am - 0.40625) / 0.4989909
#> • gear = (gear - 3.6875) / 0.7378041
#> • carb = (carb - 2.8125) / 1.6152
#> • .pred = 19.73744 + (ifelse(cyl == "6", 1, 0) * -1.660307) + (ifelse(cyl ...
#> ────────────────────────────────────────────────────────────────────────────────
#> 11 equations in total.

predict(orbital_obj, as_tibble(mtcars_f))
#> # A tibble: 32 × 1
#>    .pred
#>    <dbl>
#>  1  21.8
#>  2  21.2
#>  3  26.3
#>  4  19.6
#>  5  17.7
#>  6  19.0
#>  7  14.2
#>  8  23.7
#>  9  24.1
#> 10  18.5
#> # ℹ 22 more rows

Created on 2024-08-28 with reprex v2.1.0
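With a factor cyl the generated expression succeeds, but the imputed column comes back as character rather than factor. A minimal sketch of one hypothetical type-preserving variant (not the expression shown in the reprex), again assuming dplyr >= 1.1.0: wrapping the mode in a factor with the same levels as cyl gives both branches a common factor type.

```r
library(dplyr)

# A factor cyl with one missing value, levels as in mtcars_f
cyl <- factor(c(6, NA, 4), levels = c(4, 6, 8))

# Expression shape from the reprex: factor and character combine to character
generated <- dplyr::if_else(is.na(cyl), "8", cyl)

# Hypothetical variant: give the mode the same factor levels as cyl,
# so both branches share the factor type and the result stays a factor
preserved <- dplyr::if_else(is.na(cyl), factor("8", levels = levels(cyl)), cyl)

class(generated)
class(preserved)

# Dummy-style comparisons like those in .pred agree for both versions
generated == "6"
preserved == "6"
```

Both results work with the `cyl == "6"` comparisons in `.pred`, since `==` on a factor compares its labels, which is consistent with the second reprex predicting without error.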

