tidymodels / tune

Tools for tidy parameter tuning
https://tune.tidymodels.org

logistic regression with glmnet engine when penalty = 0 #591

Open GilHenriques opened 1 year ago

GilHenriques commented 1 year ago

The problem

When implementing a logistic regression with glmnet, I encounter two issues that I believe to be related. The reproducible example below showcases both. The issues arise when, as a reliability check, I set penalty = 0; the purpose of the check was to confirm that mixture has no effect when penalty = 0.
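
For reference (this note is not part of the original report), glmnet minimizes a penalized negative log-likelihood of, up to notation, the form

$$
\min_{\beta_0,\beta}\;\; \frac{1}{n}\sum_{i=1}^{n} \ell\!\left(y_i,\ \beta_0 + x_i^\top \beta\right) \;+\; \lambda\left[\frac{1-\alpha}{2}\lVert\beta\rVert_2^2 + \alpha\lVert\beta\rVert_1\right]
$$

where $\ell$ is the per-observation negative log-likelihood. The mixture ($\alpha$) appears only inside the penalty term, so it should have no influence on the fitted model when the penalty ($\lambda$) is zero, which is the basis for the check.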

In short, the issues are:

  1. Even though a penalty value is explicitly provided in the model specification -- logistic_reg(penalty = 0, mixture = tune()) -- I get a "no_penalty()" error when tuning the workflow. This error is also obtained for values of penalty different from zero.
  2. In an effort to avoid this error, I set penalty = tune() and then include penalty = 0 in my tuning grid. The code then runs, but contrary to my expectation, mixture has an effect on accuracy and ROC AUC.
  3. When I implement a similar model directly in the glmnet package, I confirm that when lambda = 0 (no penalty), there is no effect of alpha (mixture), whereas when lambda is larger than zero, there is an effect of alpha. This appears inconsistent with point 2 above.

Reproducible example

``` r
library(tidyverse)
library(tidymodels)

set.seed(123)

# Create an example data frame
df <- tibble(Y = sample(c(1, 0), 1000, replace = TRUE),
       X1 = rnorm(1000),
       X2 = rnorm(1000),
       X3 = rnorm(1000),
       X4 = rnorm(1000)) |> 
  mutate(Y = factor(Y))

# Initial split
splits <- initial_split(df)
train <- training(splits)
folds <- vfold_cv(train)

# Issue 1: No penalty error, even though a penalty is specified
model <- logistic_reg(penalty = 0, mixture = tune()) |> set_engine('glmnet')
rec <- recipe(Y ~ ., data = train)
wflow <- workflow() |> add_model(model) |> add_recipe(rec)

wflow |> tune_grid(folds)
#> Error in `no_penalty()`:
#> ! At least one penalty value is required for glmnet.

#> Backtrace:
#>      ▆
#>   1. ├─tune::tune_grid(wflow, folds)
#>   2. └─tune:::tune_grid.workflow(wflow, folds)
#>   3.   └─tune:::tune_grid_workflow(...)
#>   4.     └─tune:::tune_grid_loop(...)
#>   5.       └─tune (local) fn_tune_grid_loop(...)
#>   6.         └─tune:::tune_grid_loop_impl(...)
#>   7.           └─tune:::compute_grid_info(workflow, grid)
#>   8.             └─tune:::compute_grid_info_model(workflow, grid, parameters_model)
#>   9.               ├─generics::min_grid(spec, grid)
#>  10.               └─tune::min_grid.logistic_reg(spec, grid)
#>  11.                 └─tune:::no_penalty(grid, sub_nm)
#>  12.                   └─rlang::abort("At least one penalty value is required for glmnet.")

# Issue 2: If penalty = 0 in the tuning grid, mixture still has an effect
model <- logistic_reg(penalty = tune(), mixture = tune()) |> set_engine('glmnet')
rec <- recipe(Y ~ ., data = train)
wflow <- workflow() |> add_model(model) |> add_recipe(rec)

reg_grid <- expand_grid(penalty = 0, mixture = c(0.001, 0.01, 0.1, 0.25, 0.5, 0.6))

wflow |> tune_grid(folds, grid = reg_grid) |> 
  autoplot() # mixture makes a difference even though penalty = 0


# Issue 3: When we use glmnet directly, if lambda = 0 alpha makes no difference
X <- df[1:500,-1] |> as.matrix()
Y <- df[1:500,] |> pull(Y)
fit1 <- glmnet::glmnet(X, Y, family = 'binomial', lambda = 0, alpha = 0.001)
fit2 <- glmnet::glmnet(X, Y, family = 'binomial', lambda = 0, alpha = 0.1)
fit3 <- glmnet::glmnet(X, Y, family = 'binomial', lambda = 0, alpha = 0.5)

pred1 <- predict(fit1, newx = as.matrix(df[500:1000,-1]), type = 'class') |> as.vector()
pred2 <- predict(fit2, newx = as.matrix(df[500:1000,-1]), type = 'class') |> as.vector()
pred3 <- predict(fit3, newx = as.matrix(df[500:1000,-1]), type = 'class') |> as.vector()

tibble((df[500:1000,1]), pred1, pred2, pred3) |> 
  mutate(Y = as.character(Y)) |> 
  summarize(accuracy1 = sum(pred1 == Y)/n(),
            accuracy2 = sum(pred2 == Y)/n(),
            accuracy3 = sum(pred3 == Y)/n())
#> # A tibble: 1 × 3
#>   accuracy1 accuracy2 accuracy3
#>       <dbl>     <dbl>     <dbl>
#> 1     0.507     0.507     0.507

# ... But if lambda > 0 alpha does make a difference
X <- df[1:500,-1] |> as.matrix()
Y <- df[1:500,] |> pull(Y)
fit1 <- glmnet::glmnet(X, Y, family = 'binomial', lambda = 0.1, alpha = 0.001)
fit2 <- glmnet::glmnet(X, Y, family = 'binomial', lambda = 0.1, alpha = 0.1)
fit3 <- glmnet::glmnet(X, Y, family = 'binomial', lambda = 0.1, alpha = 0.5)

pred1 <- predict(fit1, newx = as.matrix(df[500:1000,-1]), type = 'class') |> as.vector()
pred2 <- predict(fit2, newx = as.matrix(df[500:1000,-1]), type = 'class') |> as.vector()
pred3 <- predict(fit3, newx = as.matrix(df[500:1000,-1]), type = 'class') |> as.vector()

tibble((df[500:1000,1]), pred1, pred2, pred3) |> 
  mutate(Y = as.character(Y)) |> 
  summarize(accuracy1 = sum(pred1 == Y)/n(),
            accuracy2 = sum(pred2 == Y)/n(),
            accuracy3 = sum(pred3 == Y)/n())
#> # A tibble: 1 × 3
#>   accuracy1 accuracy2 accuracy3
#>       <dbl>     <dbl>     <dbl>
#> 1     0.509     0.489     0.491

```

Created on 2022-12-09 with reprex v2.0.2
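
One plausible explanation for the apparent contradiction between points 2 and 3 (a hypothesis based on how parsnip documents its glmnet engine, not a confirmed diagnosis): parsnip fits glmnet over a sequence of lambda values and then produces predictions at the requested penalty from that path, whereas the direct glmnet() calls above fit a single model at exactly lambda = 0. Because glmnet's default lambda path depends on alpha, and predictions at an s value outside the fitted path appear to be taken from the nearest or interpolated path point rather than from a refit, a "penalty = 0" prediction obtained from the path can vary with mixture. The sketch below illustrates this with glmnet alone, reusing the X and Y defined in the reprex:

``` r
# Sketch (not from the original report): compare coefficients "at" lambda = 0
# taken from glmnet path fits with different alpha values against a model fit
# at exactly lambda = 0. X and Y are the same objects defined in the reprex.
X <- df[1:500, -1] |> as.matrix()
Y <- df[1:500, ] |> pull(Y)

path_a <- glmnet::glmnet(X, Y, family = "binomial", alpha = 0.001)
path_b <- glmnet::glmnet(X, Y, family = "binomial", alpha = 0.5)
exact0 <- glmnet::glmnet(X, Y, family = "binomial", lambda = 0, alpha = 0.5)

# With the default exact = FALSE, coef(fit, s = 0) is derived from the lambdas
# actually present in each fitted path (which differ across alpha), not from a
# refit at lambda = 0, so the first two columns need not match the third.
cbind(
  path_alpha_0.001 = as.vector(coef(path_a, s = 0)),
  path_alpha_0.5   = as.vector(coef(path_b, s = 0)),
  single_lambda_0  = as.vector(coef(exact0))
)
```

If that is indeed the mechanism, it would also explain why the resampling metrics in issue 2 varied with mixture even with penalty = 0 in the grid, while the single-lambda fits in issue 3 did not.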

Session info

``` r
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.0 (2022-04-22)
#>  os       macOS Monterey 12.6
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Europe/Stockholm
#>  date     2022-12-09
#>  pandoc   2.19.2 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/tools/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package * version date (UTC) lib source
#>  assertthat 0.2.1 2019-03-21 [1] CRAN (R 4.2.0)
#>  backports 1.4.1 2021-12-13 [1] CRAN (R 4.2.0)
#>  broom * 1.0.1 2022-08-29 [1] CRAN (R 4.2.0)
#>  cellranger 1.1.0 2016-07-27 [1] CRAN (R 4.2.0)
#>  class 7.3-20 2022-01-16 [1] CRAN (R 4.2.0)
#>  cli 3.4.1 2022-09-23 [1] CRAN (R 4.2.0)
#>  codetools 0.2-18 2020-11-04 [1] CRAN (R 4.2.0)
#>  colorspace 2.0-3 2022-02-21 [1] CRAN (R 4.2.0)
#>  crayon 1.5.2 2022-09-29 [1] CRAN (R 4.2.0)
#>  curl 4.3.2 2021-06-23 [1] CRAN (R 4.2.0)
#>  DBI 1.1.2 2021-12-20 [1] CRAN (R 4.2.0)
#>  dbplyr 2.2.0 2022-06-05 [1] CRAN (R 4.2.0)
#>  dials * 1.0.0 2022-06-14 [1] CRAN (R 4.2.0)
#>  DiceDesign 1.9 2021-02-13 [1] CRAN (R 4.2.0)
#>  digest 0.6.29 2021-12-01 [1] CRAN (R 4.2.0)
#>  dplyr * 1.0.10 2022-09-01 [1] CRAN (R 4.2.0)
#>  ellipsis 0.3.2 2021-04-29 [1] CRAN (R 4.2.0)
#>  evaluate 0.16 2022-08-09 [1] CRAN (R 4.2.0)
#>  fansi 1.0.3 2022-03-24 [1] CRAN (R 4.2.0)
#>  farver 2.1.1 2022-07-06 [1] CRAN (R 4.2.0)
#>  fastmap 1.1.0 2021-01-25 [1] CRAN (R 4.2.0)
#>  forcats * 0.5.2 2022-08-19 [1] CRAN (R 4.2.0)
#>  foreach 1.5.2 2022-02-02 [1] CRAN (R 4.2.0)
#>  fs 1.5.2 2021-12-08 [1] CRAN (R 4.2.0)
#>  furrr 0.3.1 2022-08-15 [1] CRAN (R 4.2.0)
#>  future 1.27.0 2022-07-22 [1] CRAN (R 4.2.0)
#>  future.apply 1.9.0 2022-04-25 [1] CRAN (R 4.2.0)
#>  gargle 1.2.0 2021-07-02 [1] CRAN (R 4.2.0)
#>  generics 0.1.3 2022-07-05 [1] CRAN (R 4.2.0)
#>  ggplot2 * 3.3.6 2022-05-03 [1] CRAN (R 4.2.0)
#>  glmnet * 4.1-4 2022-04-15 [1] CRAN (R 4.2.0)
#>  globals 0.15.1 2022-06-24 [1] CRAN (R 4.2.0)
#>  glue 1.6.2 2022-02-24 [1] CRAN (R 4.2.0)
#>  googledrive 2.0.0 2021-07-08 [1] CRAN (R 4.2.0)
#>  googlesheets4 1.0.0 2021-07-21 [1] CRAN (R 4.2.0)
#>  gower 1.0.0 2022-02-03 [1] CRAN (R 4.2.0)
#>  GPfit 1.0-8 2019-02-08 [1] CRAN (R 4.2.0)
#>  gtable 0.3.1 2022-09-01 [1] CRAN (R 4.2.0)
#>  hardhat 1.2.0 2022-06-30 [1] CRAN (R 4.2.0)
#>  haven 2.5.1 2022-08-22 [1] CRAN (R 4.2.0)
#>  highr 0.9 2021-04-16 [1] CRAN (R 4.2.0)
#>  hms 1.1.2 2022-08-19 [1] CRAN (R 4.2.0)
#>  htmltools 0.5.3 2022-07-18 [1] CRAN (R 4.2.0)
#>  httr 1.4.3 2022-05-04 [1] CRAN (R 4.2.0)
#>  infer * 1.0.2 2022-06-10 [1] CRAN (R 4.2.0)
#>  ipred 0.9-13 2022-06-02 [1] CRAN (R 4.2.0)
#>  iterators 1.0.14 2022-02-05 [1] CRAN (R 4.2.0)
#>  jsonlite 1.8.2 2022-10-02 [1] CRAN (R 4.2.0)
#>  knitr 1.40 2022-08-24 [1] CRAN (R 4.2.0)
#>  labeling 0.4.2 2020-10-20 [1] CRAN (R 4.2.0)
#>  lattice 0.20-45 2021-09-22 [1] CRAN (R 4.2.0)
#>  lava 1.6.10 2021-09-02 [1] CRAN (R 4.2.0)
#>  lhs 1.1.5 2022-03-22 [1] CRAN (R 4.2.0)
#>  lifecycle 1.0.2 2022-09-09 [1] CRAN (R 4.2.0)
#>  listenv 0.8.0 2019-12-05 [1] CRAN (R 4.2.0)
#>  lubridate 1.8.0 2021-10-07 [1] CRAN (R 4.2.0)
#>  magrittr 2.0.3 2022-03-30 [1] CRAN (R 4.2.0)
#>  MASS 7.3-56 2022-03-23 [1] CRAN (R 4.2.0)
#>  Matrix * 1.5-1 2022-09-13 [1] CRAN (R 4.2.0)
#>  mime 0.12 2021-09-28 [1] CRAN (R 4.2.0)
#>  modeldata * 1.0.0 2022-07-01 [1] CRAN (R 4.2.0)
#>  modelr 0.1.8 2020-05-19 [1] CRAN (R 4.2.0)
#>  munsell 0.5.0 2018-06-12 [1] CRAN (R 4.2.0)
#>  nnet 7.3-17 2022-01-13 [1] CRAN (R 4.2.0)
#>  parallelly 1.32.1 2022-07-21 [1] CRAN (R 4.2.0)
#>  parsnip * 1.0.0 2022-06-16 [1] CRAN (R 4.2.0)
#>  pillar 1.8.1 2022-08-19 [1] CRAN (R 4.2.0)
#>  pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.2.0)
#>  prodlim 2019.11.13 2019-11-17 [1] CRAN (R 4.2.0)
#>  purrr * 0.3.4 2020-04-17 [1] CRAN (R 4.2.0)
#>  R.cache 0.16.0 2022-07-21 [1] CRAN (R 4.2.0)
#>  R.methodsS3 1.8.2 2022-06-13 [1] CRAN (R 4.2.0)
#>  R.oo 1.25.0 2022-06-12 [1] CRAN (R 4.2.0)
#>  R.utils 2.12.2 2022-11-11 [1] CRAN (R 4.2.0)
#>  R6 2.5.1 2021-08-19 [1] CRAN (R 4.2.0)
#>  Rcpp 1.0.9 2022-07-08 [1] CRAN (R 4.2.0)
#>  readr * 2.1.3 2022-10-01 [1] CRAN (R 4.2.0)
#>  readxl 1.4.1 2022-08-17 [1] CRAN (R 4.2.0)
#>  recipes * 1.0.1 2022-07-07 [1] CRAN (R 4.2.0)
#>  reprex 2.0.2 2022-08-17 [1] CRAN (R 4.2.0)
#>  rlang 1.0.6 2022-09-24 [1] CRAN (R 4.2.0)
#>  rmarkdown 2.14 2022-04-25 [1] CRAN (R 4.2.0)
#>  rpart 4.1.16 2022-01-24 [1] CRAN (R 4.2.0)
#>  rsample * 1.0.0 2022-06-24 [1] CRAN (R 4.2.0)
#>  rstudioapi 0.13 2020-11-12 [1] CRAN (R 4.2.0)
#>  rvest 1.0.2 2021-10-16 [1] CRAN (R 4.2.0)
#>  scales * 1.2.1 2022-08-20 [1] CRAN (R 4.2.0)
#>  sessioninfo 1.2.2 2021-12-06 [1] CRAN (R 4.2.0)
#>  shape 1.4.6 2021-05-19 [1] CRAN (R 4.2.0)
#>  stringi 1.7.8 2022-07-11 [1] CRAN (R 4.2.0)
#>  stringr * 1.4.1 2022-08-20 [1] CRAN (R 4.2.0)
#>  styler 1.8.1 2022-11-07 [1] CRAN (R 4.2.0)
#>  survival 3.3-1 2022-03-03 [1] CRAN (R 4.2.0)
#>  tibble * 3.1.8 2022-07-22 [1] CRAN (R 4.2.0)
#>  tidymodels * 1.0.0 2022-07-13 [1] CRAN (R 4.2.0)
#>  tidyr * 1.2.1 2022-09-08 [1] CRAN (R 4.2.0)
#>  tidyselect 1.1.2 2022-02-21 [1] CRAN (R 4.2.0)
#>  tidyverse * 1.3.2 2022-07-18 [1] CRAN (R 4.2.0)
#>  timeDate 4021.104 2022-07-19 [1] CRAN (R 4.2.0)
#>  tune * 1.0.0 2022-07-07 [1] CRAN (R 4.2.0)
#>  tzdb 0.3.0 2022-03-28 [1] CRAN (R 4.2.0)
#>  utf8 1.2.2 2021-07-24 [1] CRAN (R 4.2.0)
#>  vctrs 0.4.2 2022-09-29 [1] CRAN (R 4.2.0)
#>  withr 2.5.0 2022-03-03 [1] CRAN (R 4.2.0)
#>  workflows * 1.0.0 2022-07-05 [1] CRAN (R 4.2.0)
#>  workflowsets * 1.0.0 2022-07-12 [1] CRAN (R 4.2.0)
#>  xfun 0.33 2022-09-12 [1] CRAN (R 4.2.0)
#>  xml2 1.3.3 2021-11-30 [1] CRAN (R 4.2.0)
#>  yaml 2.3.5 2022-02-21 [1] CRAN (R 4.2.0)
#>  yardstick * 1.0.0 2022-06-06 [1] CRAN (R 4.2.0)
#>
#>  [1] /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/library
#>
#> ──────────────────────────────────────────────────────────────────────────────
```
simonpcouch commented 10 months ago

Thank you for the issue! Just wanted to let you know this hasn't fallen off our radar. Related to https://github.com/tidymodels/tune/issues/28 and https://github.com/tidymodels/tune/issues/45.

marcozanotti commented 8 months ago

+1 Thank you @simonpcouch. In the meantime, is there any way to work around this?
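
Not an answer from the maintainers, but one possible interim workaround to experiment with, building on the issue-2 approach from the reprex: tune penalty together with mixture, pin penalty to 0 in the grid, and (assuming a parsnip version that supports the path_values engine option for glmnet) include 0 in the supplied lambda path so that predictions at penalty = 0 come from an actual lambda = 0 fit rather than from the default, alpha-dependent path. A minimal sketch, reusing train and folds from the reprex:

``` r
# Sketch of a possible workaround (unverified for this issue). The
# `path_values` engine option and the specific lambda values below are
# illustrative choices, not a recommendation from the tune/parsnip authors.
model <- logistic_reg(penalty = tune(), mixture = tune()) |>
  set_engine("glmnet", path_values = c(0, 10^seq(-4, 1, length.out = 19)))

wflow <- workflow() |>
  add_model(model) |>
  add_recipe(recipe(Y ~ ., data = train))

# Keep penalty fixed at 0 in the grid so that only mixture actually varies.
reg_grid <- expand_grid(penalty = 0, mixture = c(0.001, 0.1, 0.5))

wflow |>
  tune_grid(folds, grid = reg_grid) |>
  collect_metrics()
```

Supplying the path explicitly means the same lambda sequence is used for every mixture value, which is the point of the workaround; whether this fully removes the mixture dependence at penalty = 0 is untested here.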