tidymodels / parsnip

A tidy unified interface to models
https://parsnip.tidymodels.org
Other
595 stars 88 forks source link

Error in cbind2(1, newx) %*% nbeta : invalid class 'NA' to dup_mMatrix_as_dgeMatrix #200

Closed konradsemsch closed 5 years ago

konradsemsch commented 5 years ago

I took the example listed in this blogpost and tried to replicate it using glmnet: https://www.alexpghayes.com/blog/implementing-the-super-learner-with-tidymodels/

I wanted to use binary classification, so I excluded one of the factor levels, but otherwise changed as little as possible in order to run it. When I get to the part where I want to make predictions on the split's assessment set, I get the following error:

Error in cbind2(1, newx) %*% nbeta : 
  invalid class 'NA' to dup_mMatrix_as_dgeMatrix

More specifically it's breaking in this part when I'm trying to make predictions on the hold-out set:

en_fits_cv_pred <- en_fits_cv %>%
  mutate(
    preds = future_pmap(list(fit, splits, prepped), predict_helper)
  )

I also tried running the prediction using only one model fit to rule out the possibility of something breaking in the map, but the error persists:

predict(en_fits_cv$fit[[1]], new_data = juice(prep(en_rec, retain = TRUE)))

The full code I'm running is the following:

set.seed(42)

# Loading libraries -------------------------------------------------------

library(magrittr)
library(tidyverse)
library(tidymodels)
library(dials)
library(furrr)

# Loading input dataset ---------------------------------------------------

# Drop the setosa rows and rebuild Species as a two-level factor
# (versicolor / virginica) so the outcome is binary.
df_all <- iris %>% 
  filter(Species != "setosa") %>% 
  mutate(Species = factor(Species, levels = c("versicolor", "virginica")))

# Dividing the dataset ----------------------------------------------------

# 5-fold cross-validation, a single repeat.
df_train_cv <- vfold_cv(df_all, v = 5, repeats = 1)

# Preparing the recipes ----------------------------------------------------

# I need to add a custom step over here on the missing patterns

# Species is the outcome and every other column is a predictor; the PCA step
# is applied to all predictors and keeps 2 components.
en_rec <- df_all %>% 
  recipe(Species ~ .) %>% 
  step_pca(all_predictors(), num_comp = 2)

# Training models within resamples ----------------------------------------

fit_on_fold <- function(spec, prepped) {
  # Fit a single model specification on a single prepped recipe.
  #
  # spec:    the model specification handed to fit_xy()
  # prepped: a prepped recipe; juice() pulls the predictor columns and the
  #          outcome columns out of it separately
  #
  # Returns whatever fit_xy() returns for that predictor/outcome pair.
  fit_xy(
    spec,
    x = juice(prepped, all_predictors()),
    y = juice(prepped, all_outcomes())
  )
}

# Logistic regression specification (classification mode) on the glmnet engine.
en_engine <- logistic_reg(mode = "classification") %>% 
  set_engine("glmnet")

# Regular grid over penalty and mixture with 2 levels each.
en_grid <- grid_regular(penalty, mixture, levels = c(2, 2))

# One row per candidate specification, tagged with a sequential model_id.
en_spec <- tibble(spec = merge(en_engine, en_grid)) %>%  # combining model engine with different parameters
  mutate(model_id = row_number())

# Every combination of resample row and model specification row.
en_spec_cv <- crossing(df_train_cv, en_spec) # adding cross-validated folds

# prepped: prepper() applied to each split together with en_rec.
# fit:     fit_on_fold() applied to each (spec, prepped) pair.
en_fits_cv <- en_spec_cv %>% # fitting different model specifications to different folds
  mutate(
    prepped = future_map(splits, prepper, en_rec),
    fit = future_map2(spec, prepped, fit_on_fold)
  )

# Making holdout predictions ----------------------------------------------

predict_helper <- function(fit, new_data, recipe) {

  # Predict class probabilities for held-out or genuinely new observations.
  #
  # fit:      a fitted model, passed to predict()
  # new_data: either an rsample::rsplit object (its assessment set is used)
  #           or a data frame of genuinely new data
  # recipe:   a prepped recipe used to bake the data before predicting
  #
  # Returns the predict() output with an extra `obs` column: the assessment
  # row indices for an rsplit, otherwise 1..nrow(new_data).

  if (inherits(new_data, "rsplit")) {
    obs <- as.integer(new_data, data = "assessment")

    # never forget to bake when predicting with recipes!
    # Bake the predictors only. The model is fit on predictor columns alone,
    # and leaving the factor outcome in the baked data is the likely cause of
    # glmnet's "invalid class 'NA' to dup_mMatrix_as_dgeMatrix" error, since
    # the data can no longer be turned into a numeric matrix.
    new_data <- bake(recipe, new_data = assessment(new_data), all_predictors())
  } else {
    # seq_len() yields an empty index for zero rows, where 1:nrow() would
    # produce c(1, 0).
    obs <- seq_len(nrow(new_data))
    new_data <- bake(recipe, new_data = new_data, all_predictors())
  }

  # if you want to generalize this code to a regression
  # super learner, you'd need to set `type = "numeric"` here

  predict(fit, new_data, type = "prob") %>%
    mutate(obs = obs)
}

# Call predict_helper once per row; the unnamed list is matched positionally,
# so fit -> fit, splits -> new_data, prepped -> recipe.
en_fits_cv_pred <- en_fits_cv %>%
  mutate(
    preds = future_pmap(list(fit, splits, prepped), predict_helper)
  )

I've been looking for help around the internet, but unfortunately I'm absolutely lost about what the root cause could be. Could anyone assist?

My session info below:

> session_info()
─ Session info ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
 setting  value                       
 version  R version 3.6.0 (2019-04-26)
 os       macOS Mojave 10.14.5        
 system   x86_64, darwin15.6.0        
 ui       RStudio                     
 language (EN)                        
 collate  en_US.UTF-8                 
 ctype    en_US.UTF-8                 
 tz       Europe/Berlin               
 date     2019-07-31                  

─ Packages ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
 ! package       * version    date       lib source                             
   abind           1.4-5      2016-07-21 [1] CRAN (R 3.6.0)                     
   assertthat      0.2.1      2019-03-21 [1] CRAN (R 3.6.0)                     
   backports       1.1.4      2019-04-10 [1] CRAN (R 3.6.0)                     
   base64enc       0.1-3      2015-07-28 [1] CRAN (R 3.6.0)                     
   bayesplot       1.7.0      2019-05-23 [1] CRAN (R 3.6.0)                     
   bitops          1.0-6      2013-08-17 [1] CRAN (R 3.6.0)                     
   boot            1.3-22     2019-04-02 [1] CRAN (R 3.6.0)                     
   broom         * 0.5.2      2019-04-07 [1] CRAN (R 3.6.0)                     
   C50             0.1.2      2018-05-22 [1] CRAN (R 3.6.0)                     
   callr           3.3.0      2019-07-04 [1] CRAN (R 3.6.0)                     
   caret         * 6.0-84     2019-04-27 [1] CRAN (R 3.6.0)                     
   caTools         1.17.1.2   2019-03-06 [1] CRAN (R 3.6.0)                     
   cellranger      1.1.0      2016-07-27 [1] CRAN (R 3.6.0)                     
   class           7.3-15     2019-01-01 [1] CRAN (R 3.6.0)                     
   cli             1.1.0      2019-03-19 [1] CRAN (R 3.6.0)                     
   codetools       0.2-16     2018-12-24 [1] CRAN (R 3.6.0)                     
   colorspace      1.4-1      2019-03-18 [1] CRAN (R 3.6.0)                     
   colourpicker    1.0        2017-09-27 [1] CRAN (R 3.6.0)                     
   crayon          1.3.4      2017-09-16 [1] CRAN (R 3.6.0)                     
   crosstalk       1.0.0      2016-12-21 [1] CRAN (R 3.6.0)                     
   Cubist          0.2.2      2018-05-21 [1] CRAN (R 3.6.0)                     
   curl            3.3        2019-01-10 [1] CRAN (R 3.6.0)                     
   data.table      1.12.2     2019-04-07 [1] CRAN (R 3.6.0)                     
   desc            1.2.0      2018-05-01 [1] CRAN (R 3.6.0)                     
   devtools      * 2.0.2      2019-04-08 [1] CRAN (R 3.6.0)                     
   dials         * 0.0.2      2018-12-09 [1] CRAN (R 3.6.0)                     
   digest          0.6.20     2019-07-04 [1] CRAN (R 3.6.0)                     
   DMwR          * 0.4.1      2013-08-08 [1] CRAN (R 3.6.0)                     
   dplyr         * 0.8.1      2019-05-14 [1] CRAN (R 3.6.0)                     
   DT              0.7.1      2019-06-27 [1] Github (rstudio/DT@c6fd864)        
   dygraphs        1.1.1.6    2018-07-11 [1] CRAN (R 3.6.0)                     
   e1071           1.7-1      2019-03-19 [1] CRAN (R 3.6.0)                     
   fansi           0.4.0      2018-10-05 [1] CRAN (R 3.6.0)                     
   forcats       * 0.4.0      2019-02-17 [1] CRAN (R 3.6.0)                     
   foreach       * 1.4.4      2017-12-12 [1] CRAN (R 3.6.0)                     
   Formula         1.2-3      2018-05-03 [1] CRAN (R 3.6.0)                     
   fs              1.3.1      2019-05-06 [1] CRAN (R 3.6.0)                     
   furrr         * 0.1.0      2018-05-16 [1] CRAN (R 3.6.0)                     
   future        * 1.13.0     2019-05-08 [1] CRAN (R 3.6.0)                     
   gdata           2.18.0     2017-06-06 [1] CRAN (R 3.6.0)                     
   generics        0.0.2      2018-11-29 [1] CRAN (R 3.6.0)                     
   ggplot2       * 3.2.0      2019-06-16 [1] CRAN (R 3.6.0)                     
   ggridges        0.5.1      2018-09-27 [1] CRAN (R 3.6.0)                     
   glmnet        * 2.0-18     2019-05-20 [1] CRAN (R 3.6.0)                     
   globals         0.12.4     2018-10-11 [1] CRAN (R 3.6.0)                     
   glue            1.3.1      2019-03-12 [1] CRAN (R 3.6.0)                     
   gower           0.2.1      2019-05-14 [1] CRAN (R 3.6.0)                     
   gplots          3.0.1.1    2019-01-27 [1] CRAN (R 3.6.0)                     
   gridExtra       2.3        2017-09-09 [1] CRAN (R 3.6.0)                     
   gtable          0.3.0      2019-03-25 [1] CRAN (R 3.6.0)                     
   gtools          3.8.1      2018-06-26 [1] CRAN (R 3.6.0)                     
   haven           2.1.0      2019-02-19 [1] CRAN (R 3.6.0)                     
   hms             0.4.2      2018-03-10 [1] CRAN (R 3.6.0)                     
   htmltools       0.3.6      2017-04-28 [1] CRAN (R 3.6.0)                     
   htmlwidgets     1.3        2018-09-30 [1] CRAN (R 3.6.0)                     
   httpuv          1.5.1      2019-04-05 [1] CRAN (R 3.6.0)                     
   httr            1.4.0      2018-12-11 [1] CRAN (R 3.6.0)                     
   igraph          1.2.4.1    2019-04-22 [1] CRAN (R 3.6.0)                     
   infer         * 0.4.0.1    2019-04-22 [1] CRAN (R 3.6.0)                     
   inline          0.3.15     2018-05-18 [1] CRAN (R 3.6.0)                     
   inum            1.0-1      2019-04-25 [1] CRAN (R 3.6.0)                     
   ipred           0.9-9      2019-04-28 [1] CRAN (R 3.6.0)                     
   iterators       1.0.10     2018-07-13 [1] CRAN (R 3.6.0)                     
   janeaustenr     0.1.5      2017-06-10 [1] CRAN (R 3.6.0)                     
   jsonlite        1.6        2018-12-07 [1] CRAN (R 3.6.0)                     
   KernSmooth      2.23-15    2015-06-29 [1] CRAN (R 3.6.0)                     
   knitr           1.22       2019-03-08 [1] CRAN (R 3.6.0)                     
   later           0.8.0      2019-02-11 [1] CRAN (R 3.6.0)                     
   lattice       * 0.20-38    2018-11-04 [1] CRAN (R 3.6.0)                     
   lava            1.6.5      2019-02-12 [1] CRAN (R 3.6.0)                     
   lazyeval        0.2.2      2019-03-15 [1] CRAN (R 3.6.0)                     
   libcoin         1.0-4      2019-02-28 [1] CRAN (R 3.6.0)                     
   listenv         0.7.0      2018-01-21 [1] CRAN (R 3.6.0)                     
   lme4            1.1-21     2019-03-05 [1] CRAN (R 3.6.0)                     
   loo             2.1.0      2019-03-13 [1] CRAN (R 3.6.0)                     
   lubridate       1.7.4      2018-04-11 [1] CRAN (R 3.6.0)                     
   magrittr      * 1.5        2014-11-22 [1] CRAN (R 3.6.0)                     
   markdown        0.9        2018-12-07 [1] CRAN (R 3.6.0)                     
   MASS            7.3-51.4   2019-03-31 [1] CRAN (R 3.6.0)                     
   Matrix        * 1.2-17     2019-03-22 [1] CRAN (R 3.6.0)                     
   matrixStats     0.54.0     2018-07-23 [1] CRAN (R 3.6.0)                     
   memoise         1.1.0      2017-04-21 [1] CRAN (R 3.6.0)                     
   mime            0.7        2019-06-11 [1] CRAN (R 3.6.0)                     
   miniUI          0.1.1.1    2018-05-18 [1] CRAN (R 3.6.0)                     
   minqa           1.2.4      2014-10-09 [1] CRAN (R 3.6.0)                     
   MLmetrics       1.1.1      2016-05-13 [1] CRAN (R 3.6.0)                     
   ModelMetrics    1.2.2      2018-11-03 [1] CRAN (R 3.6.0)                     
   modelr          0.1.4      2019-02-18 [1] CRAN (R 3.6.0)                     
   munsell         0.5.0      2018-06-12 [1] CRAN (R 3.6.0)                     
   mvtnorm         1.0-10     2019-03-05 [1] CRAN (R 3.6.0)                     
   naniar          0.4.2      2019-02-15 [1] CRAN (R 3.6.0)                     
   nlme            3.1-139    2019-04-09 [1] CRAN (R 3.6.0)                     
   nloptr          1.2.1      2018-10-03 [1] CRAN (R 3.6.0)                     
   nnet            7.3-12     2016-02-02 [1] CRAN (R 3.6.0)                     
   packrat         0.5.0      2018-11-14 [1] CRAN (R 3.6.0)                     
 V parsnip       * 0.0.2      2019-07-31 [1] Github (tidymodels/parsnip@54fcc26)
   partykit        1.2-3      2019-01-31 [1] CRAN (R 3.6.0)                     
   pillar          1.4.2      2019-06-29 [1] CRAN (R 3.6.0)                     
   pkgbuild        1.0.3      2019-03-20 [1] CRAN (R 3.6.0)                     
   pkgconfig       2.0.2      2018-08-16 [1] CRAN (R 3.6.0)                     
   pkgload         1.0.2      2018-10-29 [1] CRAN (R 3.6.0)                     
   plyr            1.8.4      2016-06-08 [1] CRAN (R 3.6.0)                     
   prettyunits     1.0.2      2015-07-13 [1] CRAN (R 3.6.0)                     
   pROC            1.14.0     2019-03-12 [1] CRAN (R 3.6.0)                     
   processx        3.4.0      2019-07-03 [1] CRAN (R 3.6.0)                     
   prodlim         2018.04.18 2018-04-18 [1] CRAN (R 3.6.0)                     
   promises        1.0.1      2018-04-13 [1] CRAN (R 3.6.0)                     
   ps              1.3.0      2018-12-21 [1] CRAN (R 3.6.0)                     
   purrr         * 0.3.2      2019-03-15 [1] CRAN (R 3.6.0)                     
   quantmod        0.4-14     2019-03-24 [1] CRAN (R 3.6.0)                     
   R6              2.4.0      2019-02-14 [1] CRAN (R 3.6.0)                     
   Rcpp            1.0.1      2019-03-17 [1] CRAN (R 3.6.0)                     
   readr         * 1.3.1      2018-12-21 [1] CRAN (R 3.6.0)                     
   readxl          1.3.1      2019-03-13 [1] CRAN (R 3.6.0)                     
   recipes       * 0.1.6      2019-07-02 [1] CRAN (R 3.6.0)                     
   remotes         2.0.4      2019-04-10 [1] CRAN (R 3.6.0)                     
   reshape2        1.4.3      2017-12-11 [1] CRAN (R 3.6.0)                     
   rlang           0.4.0      2019-06-25 [1] CRAN (R 3.6.0)                     
   ROCR            1.0-7      2015-03-26 [1] CRAN (R 3.6.0)                     
   rpart           4.1-15     2019-04-12 [1] CRAN (R 3.6.0)                     
   rprojroot       1.3-2      2018-01-03 [1] CRAN (R 3.6.0)                     
   rsample       * 0.0.5      2019-07-12 [1] CRAN (R 3.6.0)                     
   rsconnect       0.8.14     2019-05-03 [1] Github (rstudio/rsconnect@83f3bd7) 
   rstan           2.19.2     2019-07-09 [1] CRAN (R 3.6.0)                     
   rstanarm        2.18.2     2018-11-10 [1] CRAN (R 3.6.0)                     
   rstantools      1.5.1      2018-08-22 [1] CRAN (R 3.6.0)                     
   rstudioapi      0.10       2019-03-19 [1] CRAN (R 3.6.0)                     
   rvest           0.3.3      2019-04-11 [1] CRAN (R 3.6.0)                     
   scales        * 1.0.0      2018-08-09 [1] CRAN (R 3.6.0)                     
   sessioninfo     1.1.1      2018-11-05 [1] CRAN (R 3.6.0)                     
   shiny           1.3.2      2019-04-22 [1] CRAN (R 3.6.0)                     
   shinyjs         1.0        2018-01-08 [1] CRAN (R 3.6.0)                     
   shinystan       2.5.0      2018-05-01 [1] CRAN (R 3.6.0)                     
   shinythemes     1.1.2      2018-11-06 [1] CRAN (R 3.6.0)                     
   SnowballC       0.6.0      2019-01-15 [1] CRAN (R 3.6.0)                     
   StanHeaders     2.18.1-10  2019-06-14 [1] CRAN (R 3.6.0)                     
   stringi         1.4.3      2019-03-12 [1] CRAN (R 3.6.0)                     
   stringr       * 1.4.0      2019-02-10 [1] CRAN (R 3.6.0)                     
   survival        2.44-1.1   2019-04-01 [1] CRAN (R 3.6.0)                     
   testthat        2.1.1      2019-04-23 [1] CRAN (R 3.6.0)                     
   threejs         0.3.1      2017-08-13 [1] CRAN (R 3.6.0)                     
   tibble        * 2.1.3      2019-06-06 [1] CRAN (R 3.6.0)                     
   tidymodels    * 0.0.2      2018-11-27 [1] CRAN (R 3.6.0)                     
   tidyposterior   0.0.2      2018-11-15 [1] CRAN (R 3.6.0)                     
   tidypredict     0.4.2      2019-07-15 [1] CRAN (R 3.6.0)                     
   tidyr         * 0.8.3      2019-03-01 [1] CRAN (R 3.6.0)                     
   tidyselect      0.2.5      2018-10-11 [1] CRAN (R 3.6.0)                     
   tidytext        0.2.1      2019-06-14 [1] CRAN (R 3.6.0)                     
   tidyverse     * 1.2.1      2017-11-14 [1] CRAN (R 3.6.0)                     
   timeDate        3043.102   2018-02-21 [1] CRAN (R 3.6.0)                     
   tokenizers      0.2.1      2018-03-29 [1] CRAN (R 3.6.0)                     
   TTR             0.23-4     2018-09-20 [1] CRAN (R 3.6.0)                     
   usethis       * 1.5.0      2019-04-07 [1] CRAN (R 3.6.0)                     
   utf8            1.1.4      2018-05-24 [1] CRAN (R 3.6.0)                     
   vctrs           0.2.0      2019-07-05 [1] CRAN (R 3.6.0)                     
   visdat          0.5.3      2019-02-15 [1] CRAN (R 3.6.0)                     
   withr           2.1.2      2018-03-15 [1] CRAN (R 3.6.0)                     
   xfun            0.6        2019-04-02 [1] CRAN (R 3.6.0)                     
   xgboost         0.90.0.1   2019-07-25 [1] CRAN (R 3.6.0)                     
   xml2            1.2.0      2018-01-24 [1] CRAN (R 3.6.0)                     
   xtable          1.8-4      2019-04-21 [1] CRAN (R 3.6.0)                     
   xts             0.11-2     2018-11-05 [1] CRAN (R 3.6.0)                     
   yardstick     * 0.0.3      2019-03-08 [1] CRAN (R 3.6.0)                     
   zeallot         0.1.0      2018-01-28 [1] CRAN (R 3.6.0)                     
   zoo             1.8-5      2019-03-21 [1] CRAN (R 3.6.0)                     

[1] /Library/Frameworks/R.framework/Versions/3.6/Resources/library

 V ── Loaded and on-disk version mismatch.
topepo commented 5 years ago

Honestly, I have no idea. I rewrote the prediction helper function to be a little simpler and rearranged the arguments (OCD :-/). I also added a performance metric below.

We're working on model tuning right now that will make this a lot easier. The use of crossing() is fine but you probably won't have to do that once we have the better api in place.

set.seed(42)

# Loading libraries -------------------------------------------------------

library(magrittr)
library(tidyverse)
#> Registered S3 method overwritten by 'rvest':
#>   method            from
#>   read_xml.response xml2
library(tidymodels)
#> ── Attaching packages ──────────────────────────────────────────────────────── tidymodels 0.0.2 ──
#> ✔ broom     0.5.2       ✔ recipes   0.1.6  
#> ✔ dials     0.0.2       ✔ rsample   0.0.5  
#> ✔ infer     0.4.0.1     ✔ yardstick 0.0.3  
#> ✔ parsnip   0.0.3
#> ── Conflicts ─────────────────────────────────────────────────────────── tidymodels_conflicts() ──
#> ✖ scales::discard()  masks purrr::discard()
#> ✖ tidyr::extract()   masks magrittr::extract()
#> ✖ dplyr::filter()    masks stats::filter()
#> ✖ recipes::fixed()   masks stringr::fixed()
#> ✖ dplyr::lag()       masks stats::lag()
#> ✖ purrr::set_names() masks magrittr::set_names()
#> ✖ yardstick::spec()  masks readr::spec()
#> ✖ recipes::step()    masks stats::step()
library(dials)
library(furrr)
#> Loading required package: future

# Loading input dataset ---------------------------------------------------

# Drop the setosa rows and rebuild Species as a two-level factor
# (versicolor / virginica) so the outcome is binary.
df_all <- iris %>% 
  filter(Species != "setosa") %>% 
  mutate(Species = factor(Species, levels = c("versicolor", "virginica")))

# Dividing the dataset ----------------------------------------------------

# 5-fold cross-validation, a single repeat.
df_train_cv <- vfold_cv(df_all, v = 5, repeats = 1)

# Preparing the recipes ----------------------------------------------------

# I need to add a custom step over here on the missing patterns

# Species is the outcome and every other column is a predictor; the PCA step
# is applied to all predictors and keeps 2 components.
en_rec <- df_all %>% 
  recipe(Species ~ .) %>% 
  step_pca(all_predictors(), num_comp = 2)

# Training models within resamples ----------------------------------------

fit_on_fold <- function(spec, prepped) {
  # Fit one model specification to the data held in one prepped recipe:
  # juice() extracts the predictors and the outcome separately, and both
  # are handed to fit_xy() together with the specification.
  predictors <- juice(prepped, all_predictors())
  outcome <- juice(prepped, all_outcomes())
  fit_xy(spec, x = predictors, y = outcome)
}

# Logistic regression specification (classification mode) on the glmnet engine.
en_engine <- logistic_reg(mode = "classification") %>% 
  set_engine("glmnet")

# Regular grid over penalty and mixture with 2 levels each.
en_grid <- grid_regular(penalty, mixture, levels = c(2, 2))

# One row per candidate specification, tagged with a sequential model_id.
en_spec <- tibble(spec = merge(en_engine, en_grid)) %>%  # combining model engine with different parameters
  mutate(model_id = row_number())

# Every combination of resample row and model specification row.
en_spec_cv <- crossing(df_train_cv, en_spec) # adding cross-validated folds

# prepped: prepper() applied to each split together with en_rec.
# fit:     fit_on_fold() applied to each (spec, prepped) pair.
en_fits_cv <- en_spec_cv %>% # fitting different model specifications to different folds
  mutate(
    prepped = future_map(splits, prepper, en_rec),
    fit = future_map2(spec, prepped, fit_on_fold)
  )

predict_helper <- function(split, recipe, fit) {
  # Predict class probabilities on a split's assessment set and attach the
  # observed Species column next to them.
  #
  # split:  an rsplit whose assessment() rows are predicted
  # recipe: a prepped recipe; only its predictor columns are baked
  # fit:    the fitted model passed to predict()
  holdout <- assessment(split)
  baked_predictors <- bake(recipe, new_data = holdout, all_predictors())
  class_probs <- predict(fit, baked_predictors, type = "prob")

  bind_cols(class_probs, select(holdout, Species))
}

# Call predict_helper once per row; the unnamed list is matched positionally,
# so splits -> split, prepped -> recipe, fit -> fit.
en_fits_cv_pred <- en_fits_cv %>%
  mutate(
    preds = future_pmap(list(splits, prepped, fit), predict_helper)
  )

# Mean log loss for each (id, model_id) group, computed from the predicted
# probability of the virginica class against the observed Species.
# `id` is presumably the fold identifier created by vfold_cv() -- confirm.
indiv_estimates <- 
  en_fits_cv_pred %>% 
  unnest(preds) %>% 
  group_by(id, model_id) %>% 
  # or some other performance measure:
  mn_log_loss(truth = Species, .pred_virginica)

# Average the individual estimates within each model/metric/estimator
# combination, ignoring missing estimates.
rs_estimates <- 
  indiv_estimates %>% 
  group_by(model_id, .metric, .estimator) %>% 
  summarize(mean = mean(.estimate, na.rm = TRUE))

rs_estimates
#> # A tibble: 4 x 4
#> # Groups:   model_id, .metric [4]
#>   model_id .metric     .estimator  mean
#>      <int> <chr>       <chr>      <dbl>
#> 1        1 mn_log_loss binary     2.36 
#> 2        2 mn_log_loss binary     0.938
#> 3        3 mn_log_loss binary     9.45 
#> 4        4 mn_log_loss binary     0.691

Created on 2019-07-31 by the reprex package (v0.2.1)

konradsemsch commented 5 years ago

Thanks @topepo for taking a look!

github-actions[bot] commented 3 years ago

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.