tidymodels / poissonreg

parsnip wrappers for Poisson regression
https://poissonreg.tidymodels.org

multi_predict fails for glmnet models when generating predictions for only a single row #48

Closed joranE closed 1 year ago

joranE commented 1 year ago

multi_predict() fails on glmnet models when passing only a single row of data.

library(tidyverse)
library(tidymodels)
library(poissonreg)
# This works:
> m <- poisson_reg() %>% 
+   set_engine("glmnet") %>% 
+   fit(count ~ (.)^2, data = seniors[,2:4])
> 
> multi_predict(m,new_data = seniors[1:2,])
# A tibble: 2 × 1
  .pred            
  <list>           
1 <tibble [72 × 2]>
2 <tibble [72 × 2]>

versus...

# This does not:
> m <- poisson_reg() %>% 
+   set_engine("glmnet") %>% 
+   fit(count ~ (.)^2, data = seniors[,2:4])
> 
> multi_predict(m,new_data = seniors[1,])
Error in predict.glmnet(object = object$fit, newx = as.matrix(new_data[,  : 
  The number of variables in newx must be 3

Session info:

R version 4.1.2 (2021-11-01)
Platform: aarch64-apple-darwin20 (64-bit)
Running under: macOS Monterey 12.0.1

Matrix products: default
LAPACK: /Library/Frameworks/R.framework/Versions/4.1-arm64/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] poissonreg_1.0.1   yardstick_1.0.0    workflowsets_1.0.0 workflows_1.0.0    tune_1.0.0        
 [6] rsample_1.1.0      recipes_1.0.3      parsnip_1.0.0      modeldata_1.0.1    infer_1.0.4       
[11] dials_1.0.0        scales_1.2.1       broom_1.0.2        tidymodels_1.0.0   forcats_0.5.1     
[16] stringr_1.4.0      dplyr_1.0.9        purrr_0.3.4        readr_2.1.0        tidyr_1.2.0       
[21] tibble_3.1.7       ggplot2_3.4.0      tidyverse_1.3.1   

loaded via a namespace (and not attached):
 [1] fs_1.5.0                lubridate_1.8.0         DiceDesign_1.9          httr_1.4.2             
 [5] tools_4.1.2             backports_1.3.0         utf8_1.2.2              R6_2.5.1               
 [9] rpart_4.1-15            DBI_1.1.1               colorspace_2.0-2        nnet_7.3-16            
[13] withr_2.5.0             tidyselect_1.2.0        compiler_4.1.2          glmnet_4.1-6           
[17] cli_3.4.1               rvest_1.0.2             xml2_1.3.2              digest_0.6.28          
[21] pkgconfig_2.0.3         parallelly_1.28.1       lhs_1.1.3               dbplyr_2.1.1           
[25] rlang_1.0.6             readxl_1.3.1            rstudioapi_0.13         shape_1.4.6            
[29] generics_0.1.2          jsonlite_1.7.2          magrittr_2.0.1          Matrix_1.3-4           
[33] Rcpp_1.0.7              munsell_0.5.0           fansi_0.5.0             GPfit_1.0-8            
[37] lifecycle_1.0.3         furrr_0.2.3             stringi_1.7.5           MASS_7.3-54            
[41] grid_4.1.2              parallel_4.1.2          listenv_0.8.0           crayon_1.4.2           
[45] lattice_0.20-45         haven_2.4.3             splines_4.1.2           hms_1.1.1              
[49] pillar_1.7.0            future.apply_1.8.1-9001 codetools_0.2-18        reprex_2.0.1           
[53] glue_1.6.2              modelr_0.1.8            vctrs_0.5.1             tzdb_0.3.0             
[57] foreach_1.5.1           cellranger_1.1.0        gtable_0.3.0            future_1.23.0          
[61] assertthat_0.2.1        gower_0.2.2             prodlim_2019.11.13      class_7.3-19           
[65] survival_3.2-13         timeDate_3043.102       iterators_1.0.13        hardhat_1.2.0          
 [69] lava_1.6.10             globals_0.14.0          ellipsis_0.3.2          ipred_0.9-12

hfrick commented 1 year ago

Thanks for the report, @joranE! That's a bug. With a single row, the dimensions of the new data get dropped and are then incorrectly reconstructed here:

https://github.com/tidymodels/poissonreg/blob/dff0e25c69fcec662b60640b663ca9a6fcaa72ae/R/poisson_reg_data.R#L269
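
As a rough illustration of this kind of dimension dropping, here is a base R sketch. It is not the package code: the matrix `x`, its column names, and the `cols` vector are made up for the example. It assumes the failing step has the shape `as.matrix(new_data[, cols])`, as the error message suggests. Subsetting the columns of a one-row matrix returns a plain vector, and `as.matrix()` turns that vector into a single-column matrix, so the number of columns no longer matches the number of predictors.

```r
# Hypothetical stand-in for the predictor matrix; names and values are made up.
x <- matrix(1:6, nrow = 2, dimnames = list(NULL, c("a", "b", "c")))
cols <- c("a", "b", "c")

# Two rows: column subsetting keeps the matrix shape (2 rows, 3 columns).
two_rows <- as.matrix(x[1:2, , drop = FALSE][, cols])
dim(two_rows)

# One row: column subsetting drops to a vector of length 3,
# and as.matrix() rebuilds it as 3 rows by 1 column.
one_row <- x[1, , drop = FALSE]
dropped <- as.matrix(one_row[, cols])
dim(dropped)

# Passing drop = FALSE keeps the 1 row by 3 column shape.
kept <- one_row[, cols, drop = FALSE]
dim(kept)
```

With a single-column matrix like `dropped`, a check on the number of variables would fail in the way the error above describes, while `kept` still has one column per predictor.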

hfrick commented 1 year ago

This also affects `predict()` with a single penalty value:

library(poissonreg)
#> Loading required package: parsnip

m <- poisson_reg() %>% 
  set_engine("glmnet") %>% 
  fit(count ~ (.)^2, data = seniors[,2:4])

multi_predict(m, new_data = seniors[1:2,])
#> # A tibble: 2 × 1
#>   .pred            
#>   <list>           
#> 1 <tibble [72 × 2]>
#> 2 <tibble [72 × 2]>
multi_predict(m, new_data = seniors[1,])
#> Error in predict.glmnet(object = object$fit, newx = as.matrix(new_data[, : The number of variables in newx must be 3

predict(m, new_data = seniors[1:2,], penalty = 0.1)
#> # A tibble: 2 × 1
#>   .pred
#>   <dbl>
#> 1  724.
#> 2  724.
predict(m, new_data = seniors[1,], penalty = 0.1)
#> Error in predict.glmnet(object = object$fit, newx = as.matrix(new_data[, : The number of variables in newx must be 3

Created on 2022-12-19 with reprex v2.0.2
