tidymodels / recipes

Pipeable steps for feature engineering and data preprocessing to prepare for modeling
https://recipes.tidymodels.org

`step_impute_knn()` errors at bake time on character columns that work at prep time #926

Open mdsteiner opened 2 years ago

mdsteiner commented 2 years ago

The problem

When imputing missing values with `step_impute_knn(all_predictors())`, the error `Error in gower_work(x = x, y = y, pair_x = pair_x, pair_y = pair_y, n = n, : Column 2 of x is of class character while matching column 2 of y is of class factor` is thrown when calling `predict()` on the fitted workflow (`predict.workflow()`). The recipe appears to be applied correctly when the workflow is fitted, but not at prediction time. A workaround is to add `step_string2factor(all_nominal_predictors())` before `step_impute_knn(all_predictors())` in the recipe. Since that step is not needed during fitting, it would be desirable for `predict.workflow()` to behave the same way.
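The message is raised inside the gower package, which `step_impute_knn()` uses to compute distances, when two matched columns have different classes. One way to spot such a mismatch before baking is to compare column classes directly. The following is a minimal base-R sketch; `find_class_mismatches()` and the two toy data frames are hypothetical and not part of recipes or gower.

```r
# Hypothetical helper (not part of recipes or gower): list the shared columns
# whose primary class differs between two data frames that will be compared.
find_class_mismatches <- function(x, y) {
  shared <- intersect(names(x), names(y))
  # Only the first class element is compared (e.g. "ordered" for ordered factors).
  cls_x <- vapply(x[shared], function(col) class(col)[1], character(1))
  cls_y <- vapply(y[shared], function(col) class(col)[1], character(1))
  keep <- cls_x != cls_y
  data.frame(
    column  = shared[keep],
    class_x = unname(cls_x[keep]),
    class_y = unname(cls_y[keep]),
    stringsAsFactors = FALSE
  )
}

# Toy data mimicking the situation in the error: new data with a character
# column, reference data with the same column stored as a factor.
new_data  <- data.frame(num = c(1, NA), chr = c("a", "b"), stringsAsFactors = FALSE)
reference <- data.frame(num = c(2, 3), chr = factor(c("a", "b")))
find_class_mismatches(new_data, reference)
```

A zero-row result means every shared column has the same class on both sides.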

Reproducible example

library(tibble)
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
tidymodels_prefer()

# set up data
set.seed(42)
dat <- tibble(
  criterion = rnorm(50),
  num_pred_a = rnorm(50) + .8*criterion,
  num_pred_b = rnorm(50) + .6*criterion,
  char_pred = ifelse(criterion < .2,
                     sample(c("a", "b"), 1, prob = c(.75, .25)),
                     sample(c("a", "b"), 1, prob = c(.5, .5))))

# introduce missing values in num_pred_a (column 2) and char_pred (column 4)
dat[sample(1:nrow(dat), 8), 2] <- NA
dat[sample(1:nrow(dat), 8), 4] <- NA

dat_split <- initial_split(dat)
dat_train <- training(dat_split)
dat_test <- testing(dat_split)

# create recipe
lm_recipe <- 
  recipe(criterion ~ ., data = dat_train) %>% 
  step_impute_knn(all_predictors()) %>% 
  step_other(all_nominal_predictors()) %>% 
  step_dummy(all_nominal_predictors())

# set up the regression model
lm_model <- 
  linear_reg() %>% 
  set_engine("lm") %>% 
  set_mode("regression")

# lm workflow 
lm_workflow <- 
  workflow() %>% 
  add_recipe(lm_recipe) %>% 
  add_model(lm_model)

# Fit the regression model
lm_fit <-
  lm_workflow %>% 
  fit(dat_train)

# get predicted values
predict(lm_fit, new_data = dat_test)
#> Error in gower_work(x = x, y = y, pair_x = pair_x, pair_y = pair_y, n = n, : Column 2 of x is of class character while matching column 2 of y is of class factor

sessioninfo::session_info()
#> - Session info ---------------------------------------------------------------
#>  setting  value
#>  version  R version 4.1.2 (2021-11-01)
#>  os       Windows 10 x64 (build 19042)
#>  system   x86_64, mingw32
#>  ui       RTerm
#>  language (EN)
#>  collate  German_Switzerland.1252
#>  ctype    German_Switzerland.1252
#>  tz       Europe/Berlin
#>  date     2022-03-08
#>  pandoc   2.14.0.3 @ C:/Program Files/RStudio/bin/pandoc/ (via rmarkdown)
#> 
#> - Packages -------------------------------------------------------------------
#>  package      * version    date (UTC) lib source
#>  assertthat     0.2.1      2019-03-21 [1] CRAN (R 4.1.0)
#>  backports      1.4.1      2021-12-13 [1] CRAN (R 4.1.2)
#>  broom        * 0.7.12     2022-01-28 [1] CRAN (R 4.1.2)
#>  cachem         1.0.6      2021-08-19 [1] CRAN (R 4.1.2)
#>  class          7.3-20     2022-01-13 [1] CRAN (R 4.1.2)
#>  cli            3.2.0      2022-02-14 [1] CRAN (R 4.1.2)
#>  codetools      0.2-18     2020-11-04 [1] CRAN (R 4.1.2)
#>  colorspace     2.0-3      2022-02-21 [1] CRAN (R 4.1.2)
#>  conflicted     1.1.0      2021-11-26 [1] CRAN (R 4.1.2)
#>  crayon         1.5.0      2022-02-14 [1] CRAN (R 4.1.2)
#>  DBI            1.1.2      2021-12-20 [1] CRAN (R 4.1.2)
#>  dials        * 0.1.0      2022-01-31 [1] CRAN (R 4.1.2)
#>  DiceDesign     1.9        2021-02-13 [1] CRAN (R 4.1.1)
#>  digest         0.6.29     2021-12-01 [1] CRAN (R 4.1.2)
#>  dplyr        * 1.0.8      2022-02-08 [1] CRAN (R 4.1.2)
#>  ellipsis       0.3.2      2021-04-29 [1] CRAN (R 4.1.0)
#>  evaluate       0.15       2022-02-18 [1] CRAN (R 4.1.2)
#>  fansi          1.0.2      2022-01-14 [1] CRAN (R 4.1.2)
#>  fastmap        1.1.0      2021-01-25 [1] CRAN (R 4.1.0)
#>  foreach        1.5.2      2022-02-02 [1] CRAN (R 4.1.2)
#>  fs             1.5.2      2021-12-08 [1] CRAN (R 4.1.2)
#>  furrr          0.2.3      2021-06-25 [1] CRAN (R 4.1.1)
#>  future         1.24.0     2022-02-19 [1] CRAN (R 4.1.2)
#>  future.apply   1.8.1      2021-08-10 [1] CRAN (R 4.1.1)
#>  generics       0.1.2      2022-01-31 [1] CRAN (R 4.1.2)
#>  ggplot2      * 3.3.5      2021-06-25 [1] CRAN (R 4.1.0)
#>  globals        0.14.0     2020-11-22 [1] CRAN (R 4.1.0)
#>  glue           1.6.2      2022-02-24 [1] CRAN (R 4.1.2)
#>  gower          1.0.0      2022-02-03 [1] CRAN (R 4.1.2)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.1.1)
#>  gtable         0.3.0      2019-03-25 [1] CRAN (R 4.1.0)
#>  hardhat        0.2.0      2022-01-24 [1] CRAN (R 4.1.2)
#>  highr          0.9        2021-04-16 [1] CRAN (R 4.1.0)
#>  htmltools      0.5.2      2021-08-25 [1] CRAN (R 4.1.2)
#>  infer        * 1.0.0      2021-08-13 [1] CRAN (R 4.1.1)
#>  ipred          0.9-12     2021-09-15 [1] CRAN (R 4.1.2)
#>  iterators      1.0.14     2022-02-05 [1] CRAN (R 4.1.1)
#>  knitr          1.37       2021-12-16 [1] CRAN (R 4.1.2)
#>  lattice        0.20-45    2021-09-22 [1] CRAN (R 4.1.2)
#>  lava           1.6.10     2021-09-02 [1] CRAN (R 4.1.2)
#>  lhs            1.1.4      2022-02-20 [1] CRAN (R 4.1.2)
#>  lifecycle      1.0.1      2021-09-24 [1] CRAN (R 4.1.1)
#>  listenv        0.8.0      2019-12-05 [1] CRAN (R 4.1.0)
#>  lubridate      1.8.0      2021-10-07 [1] CRAN (R 4.1.2)
#>  magrittr       2.0.2      2022-01-26 [1] CRAN (R 4.1.2)
#>  MASS           7.3-55     2022-01-13 [1] CRAN (R 4.1.2)
#>  Matrix         1.4-0      2021-12-08 [1] CRAN (R 4.1.2)
#>  memoise        2.0.1      2021-11-26 [1] CRAN (R 4.1.2)
#>  modeldata    * 0.1.1      2021-07-14 [1] CRAN (R 4.1.1)
#>  munsell        0.5.0      2018-06-12 [1] CRAN (R 4.1.0)
#>  nnet           7.3-17     2022-01-13 [1] CRAN (R 4.1.2)
#>  parallelly     1.30.0     2021-12-17 [1] CRAN (R 4.1.2)
#>  parsnip      * 0.1.7      2021-07-21 [1] CRAN (R 4.1.1)
#>  pillar         1.7.0      2022-02-01 [1] CRAN (R 4.1.2)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.1.0)
#>  plyr           1.8.6      2020-03-03 [1] CRAN (R 4.1.0)
#>  pROC           1.18.0     2021-09-03 [1] CRAN (R 4.1.2)
#>  prodlim        2019.11.13 2019-11-17 [1] CRAN (R 4.1.1)
#>  purrr        * 0.3.4      2020-04-17 [1] CRAN (R 4.1.0)
#>  R6             2.5.1      2021-08-19 [1] CRAN (R 4.1.1)
#>  Rcpp           1.0.8      2022-01-13 [1] CRAN (R 4.1.2)
#>  recipes      * 0.2.0      2022-02-18 [1] CRAN (R 4.1.2)
#>  reprex         2.0.1      2021-08-05 [1] CRAN (R 4.1.2)
#>  rlang          1.0.2      2022-03-04 [1] CRAN (R 4.1.2)
#>  rmarkdown      2.12       2022-03-02 [1] CRAN (R 4.1.2)
#>  rpart          4.1.16     2022-01-24 [1] CRAN (R 4.1.2)
#>  rsample      * 0.1.1      2021-11-08 [1] CRAN (R 4.1.2)
#>  rstudioapi     0.13       2020-11-12 [1] CRAN (R 4.1.0)
#>  scales       * 1.1.1      2020-05-11 [1] CRAN (R 4.1.0)
#>  sessioninfo    1.2.2      2021-12-06 [1] CRAN (R 4.1.2)
#>  stringi        1.7.6      2021-11-29 [1] CRAN (R 4.1.2)
#>  stringr        1.4.0      2019-02-10 [1] CRAN (R 4.1.0)
#>  survival       3.3-1      2022-03-03 [1] CRAN (R 4.1.2)
#>  tibble       * 3.1.6      2021-11-07 [1] CRAN (R 4.1.2)
#>  tidymodels   * 0.1.4      2021-10-01 [1] CRAN (R 4.1.2)
#>  tidyr        * 1.2.0      2022-02-01 [1] CRAN (R 4.1.2)
#>  tidyselect     1.1.2      2022-02-21 [1] CRAN (R 4.1.2)
#>  timeDate       3043.102   2018-02-21 [1] CRAN (R 4.1.0)
#>  tune         * 0.1.6      2021-07-21 [1] CRAN (R 4.1.1)
#>  utf8           1.2.2      2021-07-24 [1] CRAN (R 4.1.0)
#>  vctrs          0.3.8      2021-04-29 [1] CRAN (R 4.1.0)
#>  withr          2.5.0      2022-03-03 [1] CRAN (R 4.1.2)
#>  workflows    * 0.2.4      2021-10-12 [1] CRAN (R 4.1.2)
#>  workflowsets * 0.1.0      2021-07-22 [1] CRAN (R 4.1.1)
#>  xfun           0.30       2022-03-02 [1] CRAN (R 4.1.2)
#>  yaml           2.3.5      2022-02-21 [1] CRAN (R 4.1.2)
#>  yardstick    * 0.0.9      2021-11-22 [1] CRAN (R 4.1.2)

Created on 2022-03-08 by the reprex package (v2.0.1)

juliasilge commented 2 years ago

Thank you for this report! 🙌 Overall, recipes has some problems with how factors are handled, such as #331, #715, and unfortunately others. We should plan to fix the problem you reported together with our overall factor problems.

DavisVaughan commented 2 years ago

Reproducible with only recipes, so I'm going to move it there:

library(tibble)
library(recipes)

# set up data
set.seed(42)

dat <- tibble(
  criterion = rnorm(50),
  num_pred_a = rnorm(50) + .8*criterion,
  char_pred = ifelse(
    criterion < .2,
    sample(c("a", "b"), 1, prob = c(.75, .25)),
    sample(c("a", "b"), 1, prob = c(.5, .5))
  )
)

# introduce missing values in num_pred_a (column 2) and char_pred (column 3)
dat[sample(1:nrow(dat), 8), 2] <- NA
dat[sample(1:nrow(dat), 8), 3] <- NA

rec <- recipe(criterion ~ ., data = dat) %>% 
  step_impute_knn(all_predictors())

rec_prepped <- prep(rec, dat)

bake(rec_prepped, dat)
#> Error in gower_work(x = x, y = y, pair_x = pair_x, pair_y = pair_y, n = n, : Column 1 of x is of class character while matching column 1 of y is of class factor

Created on 2022-03-09 by the reprex package (v2.0.1)
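The error indicates that the step's stored training data holds the character predictor as a factor while the data being baked still holds it as character. Besides the `step_string2factor()` workaround from the original report, another option, assuming the goal is only to keep both sides consistent, is to convert character columns to factors before the data reach the recipe, taking the levels from the training set. This is an untested base-R sketch; `chars_to_train_factors()` is a hypothetical helper, not a recipes function.

```r
# Hypothetical helper (not part of recipes): convert every character column
# of `train` to a factor in `new`, using the training values as levels, so
# that training and new data carry identical classes and levels.
chars_to_train_factors <- function(train, new = train) {
  chr_cols <- names(train)[vapply(train, is.character, logical(1))]
  for (col in chr_cols) {
    if (!col %in% names(new)) next
    # sort() drops NA by default, so NA is never used as a level.
    lvls <- sort(unique(train[[col]]))
    # Values absent from the training levels become NA in `new`.
    new[[col]] <- factor(new[[col]], levels = lvls)
  }
  new
}

# Toy data (made up for illustration).
train <- data.frame(y = 1:4, g = c("a", "b", NA, "a"), stringsAsFactors = FALSE)
test  <- data.frame(y = 5:6, g = c("b", "c"), stringsAsFactors = FALSE)
train_f <- chars_to_train_factors(train)
test_f  <- chars_to_train_factors(train, test)
levels(test_f$g)
```

Both sets would then go through the helper before `prep()` and `bake()`. A value that never occurs in training (here `"c"`) becomes `NA`, so if it sits in a predictor it would itself be imputed by the step.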

JosiahParry commented 1 year ago

I am running into the same issue, but with `step_other()` and an ordered factor. I'm using `step_other()` on an integer (an age field) that `step_other()` previously accepted. It has now been cast as an ordered factor to keep up with the changes.

EmilHvitfeldt commented 1 year ago

@JosiahParry, would you be able to produce a reprex? If this is true, we might have a larger issue at hand.

zlk0822 commented 2 months ago

I solved this issue by placing `step_impute_knn()` later in the sequence of steps. I therefore believe the correct approach is to perform the transformations first and the imputation afterwards.
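A minimal sketch of that ordering, assuming a recent version of recipes is installed; the toy data frame and its column names are made up for illustration. Because `step_dummy()` runs first, every predictor should be numeric by the time `step_impute_knn()` compares rows, so the character/factor mismatch should not arise.

```r
library(recipes)

# Toy data (made up for illustration): one numeric and one character
# predictor, with missing values in both training and test rows.
set.seed(1)
n <- 40
dat <- data.frame(
  y = rnorm(n),
  x_num = rnorm(n),
  x_chr = sample(c("a", "b", "c"), n, replace = TRUE),
  stringsAsFactors = FALSE
)
dat$x_num[c(1:4, 33)] <- NA
dat$x_chr[c(5:8, 35)] <- NA
train <- dat[1:30, ]
test <- dat[31:40, ]

# Encode the nominal predictor first and impute last: the indicator columns
# created by step_dummy() are numeric, with NA where x_chr was missing.
rec <- recipe(y ~ ., data = train) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_impute_knn(all_predictors())

prepped <- prep(rec, training = train)
baked <- bake(prepped, new_data = test)
baked
```

One trade-off of this ordering: numeric columns are imputed from the neighbors' values, so an imputed indicator may end up fractional rather than exactly 0 or 1, and distances are computed on the indicator columns instead of the original nominal column.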