tidymodels / poissonreg

parsnip wrappers for Poisson regression
https://poissonreg.tidymodels.org

Cannot predict with glmnet engine #4

Closed: nlubock closed this issue 3 years ago

nlubock commented 4 years ago

Hello,

I'm having a hard time making predictions with Poisson models fit using glmnet, whereas glm works fine. I'm new to the tidymodels universe, so forgive me if I've overlooked something trivial...

library(poissonreg)

poisson_reg() %>%
  set_engine("glm") %>%
  fit(count ~ (.)^2, data = seniors) %>%
  predict(new_data = seniors)

#> # A tibble: 8 x 1
#>    .pred
#>    <dbl>
#> 1 910.  
#> 2 539.  
#> 3  44.6 
#> 4 455.  
#> 5   3.62
#> 6  42.4 
#> 7   1.38
#> 8 280. 

The same thing, but with glmnet:

library(poissonreg)

poisson_reg() %>%
  set_engine("glmnet") %>%
  fit(count ~ (.)^2, data = seniors) %>%
  predict(new_data = seniors)

#> Error: `penalty` should be a single numeric value. `multi_predict()` can be used to get multiple predictions per row of data.
#> Run `rlang::last_error()` to see where the error occurred.

And with multi_predict():

poisson_reg() %>%
  set_engine("glmnet") %>%
  fit(count ~ (.)^2, data = seniors) %>%
  multi_predict(new_data = seniors)

#> Error in class() : 0 argument passed to 'class' which requires 1

sessionInfo() output:

R version 3.6.3 (2020-02-29)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Debian GNU/Linux 10 (buster)

Matrix products: default
BLAS/LAPACK: /usr/lib/x86_64-linux-gnu/libopenblasp-r0.3.5.so

locale:
 [1] LC_CTYPE=C.UTF-8       LC_NUMERIC=C           LC_TIME=C.UTF-8        LC_COLLATE=C.UTF-8     LC_MONETARY=C.UTF-8   
 [6] LC_MESSAGES=C          LC_PAPER=C.UTF-8       LC_NAME=C              LC_ADDRESS=C           LC_TELEPHONE=C        
[11] LC_MEASUREMENT=C.UTF-8 LC_IDENTIFICATION=C   

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] reprex_0.3.0          poissonreg_0.0.1.9000 parsnip_0.1.1        

loaded via a namespace (and not attached):
 [1] Rcpp_1.0.4.6      pillar_1.4.4      compiler_3.6.3    prettyunits_1.1.1 iterators_1.0.12  tools_3.6.3       digest_0.6.25    
 [8] evaluate_0.14     lifecycle_0.2.0   tibble_3.0.1      lattice_0.20-38   pkgconfig_2.0.3   rlang_0.4.6       Matrix_1.2-18    
[15] foreach_1.5.0     cli_2.0.2         rstudioapi_0.11   xfun_0.12         knitr_1.28        dplyr_0.8.5       generics_0.0.2   
[22] vctrs_0.3.0       fs_1.4.0          glmnet_4.0        grid_3.6.3        tidyselect_1.1.0  glue_1.4.1        R6_2.4.1         
[29] processx_3.4.2    fansi_0.4.1       rmarkdown_2.1     callr_3.4.3       whisker_0.4       tidyr_1.0.3       purrr_0.3.4      
[36] clipr_0.7.0       magrittr_1.5      ps_1.3.2          htmltools_0.4.0   codetools_0.2-16  ellipsis_0.3.1    assertthat_0.2.1 
[43] shape_1.4.4       utf8_1.1.4        crayon_1.3.4  

topepo commented 4 years ago

You need to give it a value for penalty (as weird as that seems for glmnet). We'll add a better error message.
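Following topepo's suggestion, here is a minimal sketch of the glmnet fit with a single penalty value supplied to poisson_reg(). It assumes a reasonably current poissonreg/parsnip install with glmnet available; the value 0.01 is an arbitrary illustration, not a tuned or recommended setting, and the object names glmnet_fit and glmnet_preds are just for this example.

```r
library(parsnip)     # poisson_reg(), set_engine(), fit(), and the %>% pipe
library(poissonreg)  # Poisson engines and the `seniors` example data

# Supply a single penalty value up front; 0.01 is an arbitrary
# illustrative choice, not a tuned value.
glmnet_fit <- poisson_reg(penalty = 0.01) %>%
  set_engine("glmnet") %>%
  fit(count ~ (.)^2, data = seniors)

# predict() uses the single penalty stored in the model specification,
# so it returns one prediction per row of new_data.
glmnet_preds <- predict(glmnet_fit, new_data = seniors)
glmnet_preds
```

If predictions at several penalty values are needed, the original error message points to multi_predict(), which accepts a vector of penalties; that route is not shown here and was not checked against the package versions listed in this issue.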

nlubock commented 4 years ago

Gotcha, thanks!

github-actions[bot] commented 3 years ago

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.