tidymodels / tune

Tools for tidy parameter tuning
https://tune.tidymodels.org

fit_resamples creates one-hot encoding for some datasets and not for others #262

Closed. cimentadaj closed this issue 3 years ago.

cimentadaj commented 3 years ago

The problem

In light of https://github.com/tidymodels/tune/issues/151, I'm trying to resample a model of a continuous variable against a character column without one-hot encoding the character column. I took the stackoverflow example from https://github.com/tidymodels/tune/issues/151 and found that it worked. However, when I replicated exactly the same thing with mtcars, it raised an error.

Reproducible example

Here's the example using the stackoverflow data:

library(rsample)
library(parsnip)
library(tune)
library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
library(modeldata)
data(stackoverflow)

linear_spec <-
  linear_reg() %>%
  set_engine("lm") %>%
  set_mode("regression")

########################## With stackoverflow data ####################
so_split <- initial_split(stackoverflow[c("Salary", "Country")])
so_train <- training(so_split)

# Convert factor to character
so_fold <-
  mutate(so_train, Country = as.character(Country)) %>%
  vfold_cv(v = 10)

# Returns results without errors/warnings
linear_spec %>%
  fit_resamples(
    Salary ~ Country,
    resamples = so_fold 
  )
#> # Resampling results
#> # 10-fold cross-validation 
#> # A tibble: 10 x 4
#>    splits             id     .metrics         .notes          
#>    <list>             <chr>  <list>           <list>          
#>  1 <split [3.8K/420]> Fold01 <tibble [2 × 3]> <tibble [0 × 1]>
#>  2 <split [3.8K/420]> Fold02 <tibble [2 × 3]> <tibble [0 × 1]>
#>  3 <split [3.8K/420]> Fold03 <tibble [2 × 3]> <tibble [0 × 1]>
#>  4 <split [3.8K/420]> Fold04 <tibble [2 × 3]> <tibble [0 × 1]>
#>  5 <split [3.8K/420]> Fold05 <tibble [2 × 3]> <tibble [0 × 1]>
#>  6 <split [3.8K/420]> Fold06 <tibble [2 × 3]> <tibble [0 × 1]>
#>  7 <split [3.8K/419]> Fold07 <tibble [2 × 3]> <tibble [0 × 1]>
#>  8 <split [3.8K/419]> Fold08 <tibble [2 × 3]> <tibble [0 × 1]>
#>  9 <split [3.8K/419]> Fold09 <tibble [2 × 3]> <tibble [0 × 1]>
#> 10 <split [3.8K/419]> Fold10 <tibble [2 × 3]> <tibble [0 × 1]>

This is expected, as character columns are not expanded to dummies. However, if I replace the above with mtcars, it raises the typical one-hot encoding problem of not finding variables defined in the formula:

########################## With mtcars data ####################
mt_split <- initial_split(mtcars[c("mpg", "gear")])
mt_train <- as_tibble(training(mt_split))

mt_fold <-
  mt_train %>%
  mutate(gear = as.character(gear)) %>%
  vfold_cv(v = 10)

# Returns results with errors/warnings
linear_spec %>%
  fit_resamples(
    mpg ~ gear,
    resamples = mt_fold 
  )
#> x Fold01: model (predictions): Error: Can't subset columns that don't exist.
#> ✖ Col...
#> x Fold03: model (predictions): Error: Can't subset columns that don't exist.
#> ✖ Col...
#> x Fold04: model (predictions): Error: Can't subset columns that don't exist.
#> ✖ Col...
#> x Fold05: model (predictions): Error: Can't subset columns that don't exist.
#> ✖ Col...
#> x Fold06: model (predictions): Error: Can't subset columns that don't exist.
#> ✖ Col...
#> x Fold07: model (predictions): Error: Can't subset columns that don't exist.
#> ✖ Col...
#> x Fold08: model (predictions): Error: Can't subset columns that don't exist.
#> ✖ Col...
#> x Fold09: model (predictions): Error: Can't subset columns that don't exist.
#> ✖ Col...
#> x Fold10: model (predictions): Error: Can't subset columns that don't exist.
#> ✖ Col...
#> Warning: This tuning result has notes. Example notes on model fitting include:
#> model (predictions): Error: Can't subset columns that don't exist.
#> ✖ Column `gear5` doesn't exist.
#> model (predictions): Error: Can't subset columns that don't exist.
#> ✖ Column `gear5` doesn't exist.
#> model (predictions): Error: Can't subset columns that don't exist.
#> ✖ Column `gear4` doesn't exist.
#> # Resampling results
#> # 10-fold cross-validation 
#> # A tibble: 10 x 4
#>    splits         id     .metrics         .notes          
#>    <list>         <chr>  <list>           <list>          
#>  1 <split [21/3]> Fold01 <tibble [0 × 3]> <tibble [1 × 1]>
#>  2 <split [21/3]> Fold02 <tibble [2 × 3]> <tibble [0 × 1]>
#>  3 <split [21/3]> Fold03 <tibble [0 × 3]> <tibble [1 × 1]>
#>  4 <split [21/3]> Fold04 <tibble [0 × 3]> <tibble [1 × 1]>
#>  5 <split [22/2]> Fold05 <tibble [0 × 3]> <tibble [1 × 1]>
#>  6 <split [22/2]> Fold06 <tibble [0 × 3]> <tibble [1 × 1]>
#>  7 <split [22/2]> Fold07 <tibble [0 × 3]> <tibble [1 × 1]>
#>  8 <split [22/2]> Fold08 <tibble [0 × 3]> <tibble [1 × 1]>
#>  9 <split [22/2]> Fold09 <tibble [0 × 3]> <tibble [1 × 1]>
#> 10 <split [22/2]> Fold10 <tibble [0 × 3]> <tibble [1 × 1]>

I assume this is not expected, right? Some thoughts:

  1. From what I've read in https://github.com/tidymodels/workflows/pull/53, https://github.com/tidymodels/parsnip/pull/332 and https://github.com/tidymodels/hardhat/pull/140, one-hot encoding should only happen for factor columns, and only when it is specified in default_formula_blueprint. I don't think it should happen with character columns, as it does now.

  2. Surprisingly, the error above happens in some folds but not in all of them.
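For reference, the two encodings discussed in this thread can be told apart with a minimal base-R sketch. It calls `model.matrix()` directly and is not how hardhat blueprints are implemented; it only shows which columns each encoding produces for `gear`.

```r
# gear from mtcars as a factor with levels "3", "4" and "5"
gear <- factor(mtcars$gear)

# Treatment (dummy) coding with an intercept:
# one column per level except the reference level "3"
dummy_cols <- colnames(model.matrix(~ gear))
dummy_cols
# → "(Intercept)" "gear4" "gear5"

# One-hot coding: drop the intercept so every level gets its own column
one_hot_cols <- colnames(model.matrix(~ gear - 1))
one_hot_cols
# → "gear3" "gear4" "gear5"
```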

Since I know there have been recent merges related to this problem, I installed the latest GitHub versions of parsnip, tune, hardhat and rsample. Here's my session info:

devtools::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value                       
#>  version  R version 4.0.2 (2020-06-22)
#>  os       Ubuntu 20.04.1 LTS          
#>  system   x86_64, linux-gnu           
#>  ui       X11                         
#>  language (EN)                        
#>  collate  en_US.UTF-8                 
#>  ctype    en_US.UTF-8                 
#>  tz       Europe/Berlin               
#>  date     2020-08-17                  
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version    date       lib source                             
#>  assertthat    0.2.1      2019-03-21 [1] CRAN (R 4.0.0)                     
#>  backports     1.1.8      2020-06-17 [1] CRAN (R 4.0.2)                     
#>  callr         3.4.3      2020-03-28 [1] CRAN (R 4.0.0)                     
#>  class         7.3-17     2020-04-26 [3] CRAN (R 4.0.0)                     
#>  cli           2.0.2      2020-02-28 [1] CRAN (R 4.0.0)                     
#>  codetools     0.2-16     2018-12-24 [3] CRAN (R 4.0.0)                     
#>  colorspace    1.4-1      2019-03-18 [1] CRAN (R 4.0.0)                     
#>  crayon        1.3.4      2017-09-16 [1] CRAN (R 4.0.0)                     
#>  desc          1.2.0      2018-05-01 [1] CRAN (R 4.0.0)                     
#>  devtools      2.3.0      2020-04-10 [1] CRAN (R 4.0.2)                     
#>  dials         0.0.8      2020-07-08 [1] CRAN (R 4.0.2)                     
#>  DiceDesign    1.8-1      2019-07-31 [1] CRAN (R 4.0.0)                     
#>  digest        0.6.25     2020-02-23 [1] CRAN (R 4.0.0)                     
#>  dplyr       * 1.0.2      2020-08-14 [1] Github (tidyverse/dplyr@0bea3e8)   
#>  ellipsis      0.3.1      2020-05-15 [1] CRAN (R 4.0.0)                     
#>  evaluate      0.14       2019-05-28 [1] CRAN (R 4.0.0)                     
#>  fansi         0.4.1      2020-01-08 [1] CRAN (R 4.0.0)                     
#>  foreach       1.5.0      2020-03-30 [1] CRAN (R 4.0.0)                     
#>  fs            1.5.0      2020-07-31 [1] CRAN (R 4.0.2)                     
#>  furrr         0.1.0      2018-05-16 [1] CRAN (R 4.0.0)                     
#>  future        1.18.0     2020-07-09 [1] CRAN (R 4.0.2)                     
#>  generics      0.0.2      2018-11-29 [1] CRAN (R 4.0.0)                     
#>  ggplot2       3.3.2      2020-06-19 [1] CRAN (R 4.0.1)                     
#>  globals       0.12.5     2019-12-07 [1] CRAN (R 4.0.0)                     
#>  glue          1.4.1      2020-05-13 [1] CRAN (R 4.0.0)                     
#>  gower         0.2.2      2020-06-23 [1] CRAN (R 4.0.2)                     
#>  GPfit         1.0-8      2019-02-08 [1] CRAN (R 4.0.0)                     
#>  gtable        0.3.0      2019-03-25 [1] CRAN (R 4.0.0)                     
#>  hardhat       0.1.4.9000 2020-08-17 [1] Github (tidymodels/hardhat@0e31502)
#>  highr         0.8        2019-03-20 [1] CRAN (R 4.0.0)                     
#>  htmltools     0.5.0      2020-06-16 [1] CRAN (R 4.0.1)                     
#>  ipred         0.9-9      2019-04-28 [1] CRAN (R 4.0.0)                     
#>  iterators     1.0.12     2019-07-26 [1] CRAN (R 4.0.0)                     
#>  knitr         1.29       2020-06-23 [1] CRAN (R 4.0.2)                     
#>  lattice       0.20-41    2020-04-02 [3] CRAN (R 4.0.0)                     
#>  lava          1.6.7      2020-03-05 [1] CRAN (R 4.0.0)                     
#>  lhs           1.0.2      2020-04-13 [1] CRAN (R 4.0.0)                     
#>  lifecycle     0.2.0      2020-03-06 [1] CRAN (R 4.0.0)                     
#>  listenv       0.8.0      2019-12-05 [1] CRAN (R 4.0.0)                     
#>  lubridate     1.7.9      2020-06-08 [1] CRAN (R 4.0.1)                     
#>  magrittr      1.5        2014-11-22 [1] CRAN (R 4.0.0)                     
#>  MASS          7.3-51.6   2020-04-26 [3] CRAN (R 4.0.0)                     
#>  Matrix        1.2-18     2019-11-27 [3] CRAN (R 4.0.0)                     
#>  memoise       1.1.0      2017-04-21 [1] CRAN (R 4.0.0)                     
#>  modeldata   * 0.0.2      2020-06-22 [1] CRAN (R 4.0.2)                     
#>  munsell       0.5.0      2018-06-12 [1] CRAN (R 4.0.0)                     
#>  nnet          7.3-14     2020-04-26 [3] CRAN (R 4.0.0)                     
#>  parsnip     * 0.1.3.9000 2020-08-17 [1] Github (tidymodels/parsnip@7a86bfd)
#>  pillar        1.4.6      2020-07-10 [1] CRAN (R 4.0.2)                     
#>  pkgbuild      1.1.0      2020-07-13 [1] CRAN (R 4.0.2)                     
#>  pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.0.0)                     
#>  pkgload       1.1.0      2020-05-29 [1] CRAN (R 4.0.0)                     
#>  plyr          1.8.6      2020-03-03 [1] CRAN (R 4.0.0)                     
#>  prettyunits   1.1.1      2020-01-24 [1] CRAN (R 4.0.0)                     
#>  pROC          1.16.2     2020-03-19 [1] CRAN (R 4.0.0)                     
#>  processx      3.4.3      2020-07-05 [1] CRAN (R 4.0.2)                     
#>  prodlim       2019.11.13 2019-11-17 [1] CRAN (R 4.0.0)                     
#>  ps            1.3.4      2020-08-11 [1] CRAN (R 4.0.2)                     
#>  purrr         0.3.4      2020-04-17 [1] CRAN (R 4.0.0)                     
#>  R6            2.4.1      2019-11-12 [1] CRAN (R 4.0.0)                     
#>  Rcpp          1.0.5      2020-07-06 [1] CRAN (R 4.0.2)                     
#>  recipes       0.1.13     2020-06-23 [1] CRAN (R 4.0.2)                     
#>  remotes       2.1.1      2020-02-15 [1] CRAN (R 4.0.0)                     
#>  rlang         0.4.7      2020-07-09 [1] CRAN (R 4.0.2)                     
#>  rmarkdown     2.2        2020-05-31 [1] CRAN (R 4.0.0)                     
#>  rpart         4.1-15     2019-04-12 [3] CRAN (R 4.0.0)                     
#>  rprojroot     1.3-2      2018-01-03 [1] CRAN (R 4.0.0)                     
#>  rsample     * 0.0.7.9000 2020-08-17 [1] Github (tidymodels/rsample@aa0a4ac)
#>  scales        1.1.1      2020-05-11 [1] CRAN (R 4.0.1)                     
#>  sessioninfo   1.1.1      2018-11-05 [1] CRAN (R 4.0.0)                     
#>  stringi       1.4.6      2020-02-17 [1] CRAN (R 4.0.0)                     
#>  stringr       1.4.0      2019-02-10 [1] CRAN (R 4.0.0)                     
#>  survival      3.1-12     2020-04-10 [3] CRAN (R 4.0.0)                     
#>  testthat      2.3.2      2020-03-02 [1] CRAN (R 4.0.0)                     
#>  tibble        3.0.3.9000 2020-07-29 [1] Github (tidyverse/tibble@b4eec19)  
#>  tidyr         1.1.1      2020-07-31 [1] CRAN (R 4.0.2)                     
#>  tidyselect    1.1.0      2020-05-11 [1] CRAN (R 4.0.0)                     
#>  timeDate      3043.102   2018-02-21 [1] CRAN (R 4.0.0)                     
#>  tune        * 0.1.1.9000 2020-08-17 [1] Github (tidymodels/tune@6a614d8)   
#>  usethis       1.6.1      2020-04-29 [1] CRAN (R 4.0.0)                     
#>  utf8          1.1.4      2018-05-24 [1] CRAN (R 4.0.0)                     
#>  vctrs         0.3.2      2020-07-15 [1] CRAN (R 4.0.2)                     
#>  withr         2.2.0      2020-04-20 [1] CRAN (R 4.0.0)                     
#>  workflows     0.1.3      2020-08-10 [1] CRAN (R 4.0.2)                     
#>  xfun          0.16       2020-07-24 [1] CRAN (R 4.0.2)                     
#>  yaml          2.2.1      2020-02-01 [1] CRAN (R 4.0.0)                     
#>  yardstick     0.0.7      2020-07-13 [1] CRAN (R 4.0.2)                     
#> 
#> [1] /usr/local/lib/R/site-library
#> [2] /usr/lib/R/site-library
#> [3] /usr/lib/R/library

topepo commented 3 years ago

The issue is related to how you are passing the gear column in. When it is kept as a character column, it is converted to a factor only after the data have been split, so different data sets can end up with different factor levels. For example, one data set has levels "3" and "4", and then the new data contain a value of "5" that was never seen before.
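A minimal base-R sketch of that mismatch, offered as an analogy rather than a trace of the tidymodels internals: it assumes each data set's design matrix is built independently from a character column. `analysis_df` and `assessment_df` are hypothetical stand-ins for one fold.

```r
# Hypothetical analysis/assessment sets built from mtcars, with gear kept as
# character so each set is converted to a factor on its own
cars <- mtcars[c("mpg", "gear")]
cars$gear <- as.character(cars$gear)

analysis_df   <- cars                       # contains "3", "4" and "5"
assessment_df <- cars[cars$gear != "5", ]   # contains only "3" and "4"

# model.matrix() turns each character column into a factor using only the
# values present in that data set, so the dummy columns differ
train_cols <- colnames(model.matrix(mpg ~ gear, data = analysis_df))
test_cols  <- colnames(model.matrix(mpg ~ gear, data = assessment_df))

setdiff(train_cols, test_cols)
# → "gear5"
```

The missing `gear5` column mirrors the "Column `gear5` doesn't exist" notes in the failing folds above.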

It happens in some folds and not in others because this is a very small data set, and you sometimes end up sampling out a factor level. Which folds are affected is random because the resampling is random.
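To see why only some folds fail, here is a rough base-R simulation. It assumes fold assignment by `sample()` rather than `rsample::vfold_cv()`, and the seed is arbitrary. For each fold, it checks whether the assessment set's `gear` values span a different set of levels than the analysis set's. When they differ, factors built separately on each set produce different dummy columns.

```r
set.seed(262)  # arbitrary; which folds are flagged depends on the seed
train <- mtcars[sample(nrow(mtcars), 24), c("mpg", "gear")]
train$gear <- as.character(train$gear)

# Randomly assign the 24 rows to 10 folds of 3 or 2 rows each
fold_id <- sample(rep(1:10, length.out = nrow(train)))

# TRUE when the analysis and assessment sets see different gear levels
levels_differ <- vapply(1:10, function(k) {
  analysis_levels   <- unique(train$gear[fold_id != k])
  assessment_levels <- unique(train$gear[fold_id == k])
  !setequal(analysis_levels, assessment_levels)
}, logical(1))

levels_differ
sum(levels_differ)
```

Because a 2- or 3-row assessment set can rarely cover all three levels, most folds tend to be flagged, which is in line with 9 of the 10 folds failing in the original output.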

If you convert to a factor (instead of a character), the problem goes away since the factor knows about all possible levels from the start:

library(tidymodels)
#> ── Attaching packages ───────────────────────────────────── tidymodels 0.1.1 ──
#> ✓ broom     0.7.0          ✓ recipes   0.1.13    
#> ✓ dials     0.0.8.9000     ✓ rsample   0.0.7     
#> ✓ dplyr     1.0.1          ✓ tibble    3.0.3     
#> ✓ ggplot2   3.3.2          ✓ tidyr     1.1.1     
#> ✓ infer     0.5.2          ✓ tune      0.1.1.9000
#> ✓ modeldata 0.0.2          ✓ workflows 0.1.3     
#> ✓ parsnip   0.1.3          ✓ yardstick 0.0.7     
#> ✓ purrr     0.3.4
#> Warning: package 'recipes' was built under R version 4.0.2
#> Warning: package 'rsample' was built under R version 4.0.2
#> Warning: package 'workflows' was built under R version 4.0.2
#> Warning: package 'yardstick' was built under R version 4.0.2
#> ── Conflicts ──────────────────────────────────────── tidymodels_conflicts() ──
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter()  masks stats::filter()
#> x dplyr::lag()     masks stats::lag()
#> x recipes::step()  masks stats::step()

linear_spec <-
  linear_reg() %>%
  set_engine("lm") %>%
  set_mode("regression")

mt_split <- initial_split(mtcars[c("mpg", "gear")])
mt_train <- as_tibble(training(mt_split))

mt_fold <-
  mt_train %>%
  mutate(gear = factor(gear)) %>%
  vfold_cv(v = 10)

# Returns results for every fold; some folds carry warnings (notes)
linear_spec %>%
  fit_resamples(
    mpg ~ gear,
    resamples = mt_fold
  )
#> ! Fold01: internal: A correlation computation is required, but `estimate` is const...
#> ! Fold02: internal: A correlation computation is required, but `estimate` is const...
#> ! Fold04: internal: A correlation computation is required, but `estimate` is const...
#> ! Fold05: internal: A correlation computation is required, but `truth` is constant...
#> ! Fold08: internal: A correlation computation is required, but `estimate` is const...
#> Warning: This tuning result has notes. Example notes on model fitting include:
#> internal: A correlation computation is required, but `estimate` is constant and has 0 standard deviation, resulting in a divide by 0 error. `NA` will be returned.
#> internal: A correlation computation is required, but `truth` is constant and has 0 standard deviation, resulting in a divide by 0 error. `NA` will be returned.
#> internal: A correlation computation is required, but `estimate` is constant and has 0 standard deviation, resulting in a divide by 0 error. `NA` will be returned.
#> # Resampling results
#> # 10-fold cross-validation 
#> # A tibble: 10 x 4
#>    splits         id     .metrics         .notes          
#>    <list>         <chr>  <list>           <list>          
#>  1 <split [21/3]> Fold01 <tibble [2 × 3]> <tibble [1 × 1]>
#>  2 <split [21/3]> Fold02 <tibble [2 × 3]> <tibble [1 × 1]>
#>  3 <split [21/3]> Fold03 <tibble [2 × 3]> <tibble [0 × 1]>
#>  4 <split [21/3]> Fold04 <tibble [2 × 3]> <tibble [1 × 1]>
#>  5 <split [22/2]> Fold05 <tibble [2 × 3]> <tibble [1 × 1]>
#>  6 <split [22/2]> Fold06 <tibble [2 × 3]> <tibble [0 × 1]>
#>  7 <split [22/2]> Fold07 <tibble [2 × 3]> <tibble [0 × 1]>
#>  8 <split [22/2]> Fold08 <tibble [2 × 3]> <tibble [1 × 1]>
#>  9 <split [22/2]> Fold09 <tibble [2 × 3]> <tibble [0 × 1]>
#> 10 <split [22/2]> Fold10 <tibble [2 × 3]> <tibble [0 × 1]>

Created on 2020-08-17 by the reprex package (v0.3.0)
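Why the factor conversion helps can be sketched in base R, again as an analogy rather than the tidymodels internals: a factor created before the data are subset keeps all of its levels in every subset, so the dummy columns built from each subset match even when a level is absent. The data frames are hypothetical stand-ins for an analysis and an assessment set.

```r
cars <- mtcars[c("mpg", "gear")]
# Convert to a factor before subsetting: every subset inherits all levels
cars$gear <- factor(cars$gear)

analysis_df   <- cars
assessment_df <- cars[cars$gear != "5", ]   # no rows with gear == "5"

# The unused level "5" is still known to the assessment set
levels(assessment_df$gear)
# → "3" "4" "5"

# model.matrix() keeps unused levels by default, so the columns line up
train_cols <- colnames(model.matrix(mpg ~ gear, data = analysis_df))
test_cols  <- colnames(model.matrix(mpg ~ gear, data = assessment_df))
identical(train_cols, test_cols)
# → TRUE
```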

The new warning is also caused by the small data set, since resampling can sample out 2 of the three levels of gear. This results in an intercept-only model, which predicts the same value for all samples. Since the R^2 statistic depends on the variance of the predicted values, it ends up dividing by zero (and issuing the warning).
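A minimal base-R sketch of that divide-by-zero, assuming (as the warning text above suggests) that the R^2 metric here is the squared correlation between `truth` and `estimate`. The numbers are made up for illustration.

```r
# Hypothetical assessment-set values: three observed mpg values and a
# model that predicts the same number for every row
truth    <- c(21.0, 22.8, 18.7)
estimate <- rep(20, 3)

sd(estimate)
# → 0

# Correlation-based R^2 (squared Pearson correlation) is undefined here:
# cor() warns that the standard deviation is zero and returns NA
rsq_cor <- suppressWarnings(cor(truth, estimate))^2
rsq_cor
# → NA
```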

DavisVaughan commented 3 years ago

To add to Max's answer, regarding this statement from the original post:

"This is expected, as character columns are not expanded to dummies"

I don't think this is quite right. Character columns are converted to factors and then expanded. You can see that Country has been expanded in the coefficients below:

suppressPackageStartupMessages({
  library(rsample)
  library(parsnip)
  library(tune)
  library(dplyr)
  library(modeldata)
})
data(stackoverflow)

linear_spec <-
  linear_reg() %>%
  set_engine("lm") %>%
  set_mode("regression")

so_split <- initial_split(stackoverflow[c("Salary", "Country")])
so_train <- training(so_split)

# Convert factor to character
so_fold <-
  mutate(so_train, Country = as.character(Country)) %>%
  vfold_cv(v = 10)

# Returns results without errors/warnings
mods <- linear_spec %>%
  fit_resamples(
    Salary ~ Country,
    resamples = so_fold,
    control = control_resamples(extract = identity)
  )

mods$.extracts[[1]]$.extracts[[1]]$fit$fit$fit
#> 
#> Call:
#> stats::lm(formula = ..y ~ ., data = data)
#> 
#> Coefficients:
#>             (Intercept)           CountryGermany             CountryIndia  
#>                   56795                    -4576                   -45025  
#> `CountryUnited Kingdom`   `CountryUnited States`  
#>                   -2720                    41485

Created on 2020-08-17 by the reprex package (v0.3.0.9001)
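Plain `lm()` shows the same behavior, which makes a quick sanity check independent of tidymodels. This is a minimal sketch on hypothetical toy data, not the stackoverflow data set.

```r
# Hypothetical toy data with a character predictor
toy <- data.frame(
  salary  = c(50, 52, 10, 12, 90, 95),
  country = c("Germany", "Germany", "India", "India",
              "United States", "United States"),
  stringsAsFactors = FALSE
)

fit <- lm(salary ~ country, data = toy)

# The character column is converted to a factor and expanded into
# treatment-coded dummies, one per non-reference level ("Germany" is the
# reference because levels are sorted alphabetically)
names(coef(fit))
# → "(Intercept)" "countryIndia" "countryUnited States"
```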

github-actions[bot] commented 3 years ago

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.