topepo / caret

caret (Classification And Regression Training) R package that contains misc functions for training and plotting classification and regression models
http://topepo.github.io/caret/index.html

train with recipes breaks if any roles are NA #989

Open exuberantleigh opened 5 years ago

exuberantleigh commented 5 years ago

Alright Davis, time to tell me what I'm doing wrong :P .

train() is breaking if there are any roles in the recipe that are NA. Looks like it happens when it checks for case weights (case weights are not being specified here).

(Also just noticed a bunch of warnings about using newdata instead of new_data in bake, which I'll delete the duplicates of in the reprex so it's not so cluttered.)

suppressMessages(library(dplyr))
suppressMessages(library(recipes))
suppressMessages(library(caret))

rec <- recipes::recipe(head(iris)) %>%
  update_role(Sepal.Width, new_role = "predictor") %>%
  update_role(Sepal.Length, new_role = "outcome")
rec %>% prep() %>% summary()
#> # A tibble: 5 x 4
#>   variable     type    role      source  
#>   <chr>        <chr>   <chr>     <chr>   
#> 1 Sepal.Length numeric outcome   original
#> 2 Sepal.Width  numeric predictor original
#> 3 Petal.Length numeric <NA>      original
#> 4 Petal.Width  numeric <NA>      original
#> 5 Species      nominal <NA>      original

tc <- caret::trainControl(method = "cv", number = 10)
set.seed(1234)
caret::train(rec, iris, method = "lm", trControl = tc, metric = "RMSE")
#> Error in if (any(is_weight)) {: missing value where TRUE/FALSE needed
# basically replicating what train.recipe does when checking for case weights ---
roles <- rec %>% prep() %>% summary() %>%
  purrr::pluck("role")
is_weight <- roles == "case weight"
any(is_weight)
#> [1] NA

if(any(is_weight)) print("not gonna happen")
#> Error in if (any(is_weight)) print("not gonna happen"): missing value where TRUE/FALSE needed
# work around by assigning all other variables to a role so there are no NAs ---
rec_removeNA <- rec %>%
  update_role(-contains("Sepal"), new_role = "other")
rec_removeNA %>% prep() %>% summary()
#> # A tibble: 5 x 4
#>   variable     type    role      source  
#>   <chr>        <chr>   <chr>     <chr>   
#> 1 Sepal.Length numeric outcome   original
#> 2 Sepal.Width  numeric predictor original
#> 3 Petal.Length numeric other     original
#> 4 Petal.Width  numeric other     original
#> 5 Species      nominal other     original

set.seed(1234)
caret::train(rec_removeNA, iris, method = "lm", trControl = tc, metric = "RMSE")
#> Warning: Please use `new_data` instead of `newdata` with `bake`. 
#> In recipes versions >= 0.1.4, this will cause an error.
#> Linear Regression 
#> 
#> 150 samples
#>   4 predictor
#> 
#> Recipe steps:  
#> Resampling: Cross-Validated (10 fold) 
#> Summary of sample sizes: 135, 136, 133, 135, 135, 134, ... 
#> Resampling results:
#> 
#>   RMSE       Rsquared    MAE      
#>   0.8201493  0.09020067  0.6808931
#> 
#> Tuning parameter 'intercept' was held constant at a value of TRUE

# compare to formula method just for fun ---------------------------------------
rec_form <- recipes::recipe(Sepal.Length ~ Sepal.Width, data = head(iris))
rec_form %>% prep() %>% summary()
#> # A tibble: 2 x 4
#>   variable     type    role      source  
#>   <chr>        <chr>   <chr>     <chr>   
#> 1 Sepal.Width  numeric predictor original
#> 2 Sepal.Length numeric outcome   original

set.seed(1234)
caret::train(rec_form, iris, method = "lm", trControl = tc, metric = "RMSE")
#> Linear Regression 
#> 
#> 150 samples
#>   4 predictor
#> 
#> Recipe steps:  
#> Resampling: Cross-Validated (10 fold) 
#> Summary of sample sizes: 135, 136, 133, 135, 135, 134, ... 
#> Resampling results:
#> 
#>   RMSE       Rsquared    MAE      
#>   0.8201493  0.09020067  0.6808931
#> 
#> Tuning parameter 'intercept' was held constant at a value of TRUE

Created on 2019-01-17 by the reprex package (v0.2.1)

Session info

``` r
devtools::session_info()
#> ─ Session info ──────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 3.5.1 (2018-07-02)
#>  os       macOS 10.14.2
#>  system   x86_64, darwin15.6.0
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       America/Chicago
#>  date     2019-01-17
#>
#> ─ Packages ──────────────────────────────────────────────────────────────
#>  package      * version    date       lib source
#>  assertthat     0.2.0      2017-04-11 [2] CRAN (R 3.5.0)
#>  backports      1.1.3      2018-12-14 [1] CRAN (R 3.5.0)
#>  bindr          0.1.1      2018-03-13 [1] CRAN (R 3.5.0)
#>  bindrcpp     * 0.2.2      2018-03-29 [1] CRAN (R 3.5.0)
#>  callr          3.1.0      2018-12-10 [1] CRAN (R 3.5.0)
#>  caret        * 6.0-80     2018-05-26 [1] CRAN (R 3.5.1)
#>  class          7.3-14     2015-08-30 [2] CRAN (R 3.5.1)
#>  cli            1.0.1      2018-09-25 [1] CRAN (R 3.5.0)
#>  codetools      0.2-15     2016-10-05 [2] CRAN (R 3.5.1)
#>  colorspace     1.3-2      2016-12-14 [1] CRAN (R 3.5.0)
#>  crayon         1.3.4      2017-09-16 [1] CRAN (R 3.5.0)
#>  data.table     1.11.8     2018-09-30 [1] CRAN (R 3.5.0)
#>  desc           1.2.0      2018-05-01 [2] CRAN (R 3.5.0)
#>  devtools       2.0.1      2018-10-26 [1] CRAN (R 3.5.1)
#>  digest         0.6.18     2018-10-10 [1] CRAN (R 3.5.0)
#>  dplyr        * 0.7.8      2018-11-10 [1] CRAN (R 3.5.0)
#>  evaluate       0.12       2018-10-09 [1] CRAN (R 3.5.0)
#>  fansi          0.4.0      2018-10-05 [1] CRAN (R 3.5.0)
#>  foreach        1.4.4      2017-12-12 [1] CRAN (R 3.5.0)
#>  fs             1.2.6      2018-08-23 [1] CRAN (R 3.5.0)
#>  generics       0.0.2      2018-11-29 [1] CRAN (R 3.5.0)
#>  ggplot2      * 3.1.0      2018-10-25 [1] CRAN (R 3.5.0)
#>  glue           1.3.0      2018-07-17 [2] CRAN (R 3.5.0)
#>  gower          0.1.2      2017-02-23 [1] CRAN (R 3.5.0)
#>  gtable         0.2.0      2016-02-26 [1] CRAN (R 3.5.0)
#>  htmltools      0.3.6      2017-04-28 [1] CRAN (R 3.5.0)
#>  ipred          0.9-8      2018-11-05 [1] CRAN (R 3.5.0)
#>  iterators      1.0.10     2018-07-13 [1] CRAN (R 3.5.0)
#>  knitr          1.20       2018-02-20 [1] CRAN (R 3.5.0)
#>  lattice      * 0.20-35    2017-03-25 [1] CRAN (R 3.5.0)
#>  lava           1.6.4      2018-11-25 [1] CRAN (R 3.5.0)
#>  lazyeval       0.2.1      2017-10-29 [1] CRAN (R 3.5.0)
#>  lubridate      1.7.4      2018-04-11 [1] CRAN (R 3.5.0)
#>  magrittr       1.5        2014-11-22 [1] CRAN (R 3.5.0)
#>  MASS           7.3-51     2018-10-16 [1] CRAN (R 3.5.0)
#>  Matrix         1.2-14     2018-04-13 [2] CRAN (R 3.5.1)
#>  memoise        1.1.0      2017-04-21 [1] CRAN (R 3.5.0)
#>  ModelMetrics   1.2.2      2018-11-03 [1] CRAN (R 3.5.0)
#>  munsell        0.5.0      2018-06-12 [1] CRAN (R 3.5.0)
#>  nlme           3.1-137    2018-04-07 [1] CRAN (R 3.5.0)
#>  nnet           7.3-12     2016-02-02 [2] CRAN (R 3.5.1)
#>  pillar         1.3.1      2018-12-15 [1] CRAN (R 3.5.0)
#>  pkgbuild       1.0.2      2018-10-16 [1] CRAN (R 3.5.0)
#>  pkgconfig      2.0.2      2018-08-16 [1] CRAN (R 3.5.0)
#>  pkgload        1.0.2      2018-10-29 [1] CRAN (R 3.5.0)
#>  plyr           1.8.4      2016-06-08 [1] CRAN (R 3.5.0)
#>  prettyunits    1.0.2      2015-07-13 [1] CRAN (R 3.5.0)
#>  processx       3.2.1      2018-12-05 [1] CRAN (R 3.5.0)
#>  prodlim        2018.04.18 2018-04-18 [1] CRAN (R 3.5.0)
#>  ps             1.2.1      2018-11-06 [1] CRAN (R 3.5.0)
#>  purrr          0.2.5      2018-05-29 [1] CRAN (R 3.5.0)
#>  R6             2.3.0      2018-10-04 [2] CRAN (R 3.5.0)
#>  Rcpp           1.0.0      2018-11-07 [1] CRAN (R 3.5.0)
#>  recipes      * 0.1.4      2018-11-19 [1] CRAN (R 3.5.1)
#>  remotes        2.0.2      2018-10-30 [1] CRAN (R 3.5.1)
#>  reshape2       1.4.3      2017-12-11 [1] CRAN (R 3.5.0)
#>  rlang          0.3.1      2019-01-08 [1] CRAN (R 3.5.2)
#>  rmarkdown      1.10       2018-06-11 [1] CRAN (R 3.5.0)
#>  rpart          4.1-13     2018-02-23 [2] CRAN (R 3.5.1)
#>  rprojroot      1.3-2      2018-01-03 [2] CRAN (R 3.5.0)
#>  scales         1.0.0      2018-08-09 [1] CRAN (R 3.5.0)
#>  sessioninfo    1.1.0      2018-09-25 [1] CRAN (R 3.5.0)
#>  stringi        1.2.4      2018-07-20 [1] CRAN (R 3.5.0)
#>  stringr        1.3.1      2018-05-10 [1] CRAN (R 3.5.0)
#>  survival       2.42-3     2018-04-16 [2] CRAN (R 3.5.1)
#>  testthat       2.0.1      2018-10-13 [1] CRAN (R 3.5.0)
#>  tibble         2.0.0      2019-01-04 [1] CRAN (R 3.5.2)
#>  tidyr          0.8.2      2018-10-28 [1] CRAN (R 3.5.0)
#>  tidyselect     0.2.5      2018-10-11 [1] CRAN (R 3.5.0)
#>  timeDate       3043.102   2018-02-21 [1] CRAN (R 3.5.0)
#>  usethis        1.4.0      2018-08-14 [1] CRAN (R 3.5.0)
#>  utf8           1.1.4      2018-05-24 [1] CRAN (R 3.5.0)
#>  withr          2.1.2      2018-03-15 [1] CRAN (R 3.5.0)
#>  yaml           2.2.0      2018-07-25 [1] CRAN (R 3.5.0)
#>
#> [1] /Users/lalexander/Documents/r_libs
#> [2] /Library/Frameworks/R.framework/Versions/3.5/Resources/library
```

DavisVaughan commented 5 years ago

Haven't even looked at this but can't wait to show you where you are wrong.

olangfor commented 5 years ago

Thanks @exuberantleigh . Was literally about to post the same issue (using iris as an example lol).

Let me know if I've misunderstood how to use the add_role functionality in recipes.

Thanks for your great work

exuberantleigh commented 5 years ago

I think @DavisVaughan would be in a better position to say if there's anything funny/incorrect about the way roles are being specified via recipes, or if anything is currently being developed in recipes that would affect this issue. At least in this example, it matters whether you initially set up the recipe using a formula or not.

This issue specifically occurred when the caret code was checking for case weights. It would be a pretty straightforward fix so I should probably just get my act together and create a PR for it, instead of relying on the workaround I used in the reprex. Cuz that doesn't really help anybody else who encounters this issue :) .

DavisVaughan commented 5 years ago

The problem is coming from this line in train.recipe() https://github.com/topepo/caret/blob/7dbf54c8e378295f38eb378c960fb89473eff152/pkg/caret/R/train.default.R#L1019

is_weight <- summary(trained_rec)$role == "case weight"
# [1] FALSE FALSE    NA    NA    NA

The any(is_weight) check later then returns NA and you can't do if(NA).
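
A minimal base-R sketch of why the check fails, using a hand-written `roles` vector that mimics the summary in the reprex (the vector and the `msg` variable are illustrative, not caret code). It also shows that the error only appears when no role is actually "case weight", since a single TRUE lets any() ignore the NAs.

```r
# Roles as reported by summary() in the reprex: three variables have no role.
roles <- c("outcome", "predictor", NA, NA, NA)
is_weight <- roles == "case weight"

# No TRUE is present, so the NAs decide the result.
any(is_weight)
#> [1] NA

# A single TRUE is enough for any() to return TRUE despite the NAs.
any(c(is_weight, TRUE))
#> [1] TRUE

# if() refuses an NA condition; capture the error message instead of stopping.
msg <- tryCatch(
  if (any(is_weight)) "has weights" else "no weights",
  error = function(e) conditionMessage(e)
)
msg
#> [1] "missing value where TRUE/FALSE needed"
```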

There are a couple ways to handle this. You could use any(is_weight, na.rm = TRUE), but I think it might be better if is_weight held the correct values (I expect the NA values to be FALSE).

So you might just replace the == with %in%, which would give the expected results.
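
The two options can be compared side by side in base R. This is only a sketch: `roles` and `vars` are hand-written stand-ins for the recipe summary, and the subsetting step is a hypothetical use of the vector, not something taken from the caret source.

```r
roles <- c("outcome", "predictor", NA, NA, NA)
vars  <- c("Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width", "Species")

# Option 1: keep `==` and drop the NAs inside any().
# The check passes, but the vector itself still holds NAs.
is_weight_eq <- roles == "case weight"
any(is_weight_eq, na.rm = TRUE)
#> [1] FALSE
vars[is_weight_eq]   # NA indices produce NA elements
#> [1] NA NA NA

# Option 2: %in% never returns NA, so NA roles become FALSE.
is_weight_in <- roles %in% "case weight"
is_weight_in
#> [1] FALSE FALSE FALSE FALSE FALSE
any(is_weight_in)
#> [1] FALSE
vars[is_weight_in]   # safe to use for subsetting
#> character(0)
```

With `%in%` the vector is clean at the point where it is created, so any later code that reuses it does not need its own NA handling.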

You'll also have to do it a bit further down, here: https://github.com/topepo/caret/blob/7dbf54c8e378295f38eb378c960fb89473eff152/pkg/caret/R/train.default.R#L1027

@exuberantleigh if you want to try a PR for this, I'm happy to review (I'll be more responsive on that I promise). I haven't checked to see if there are any other bugs that will pop up with NA roles after we get past those two lines, but we can tackle them as they come up.