tidymodels / dials

Tools for creating tuning parameter values
https://dials.tidymodels.org/

Range of penalty() in Ridge Regression does not cover lambda chosen by cv.glmnet #140

Closed andrjohns closed 4 years ago

andrjohns commented 4 years ago

Hi,

I'm trying to use the tidymodels ecosystem to run a ridge regression, so I'm using the grid_regular() function to create the search space of lambda values to evaluate. However, these lambda values are always bounded above by 1. When I run the same data through cv.glmnet, the lambdas it tests go well above 1, and the one it chooses is greater than 1:

library(tidymodels)
#> -- Attaching packages ----------------------------------------------------------------- tidymodels 0.1.1 --
#> v broom     0.7.0      v recipes   0.1.13
#> v dials     0.0.8      v rsample   0.0.7 
#> v dplyr     1.0.2      v tibble    3.0.3 
#> v ggplot2   3.3.2      v tidyr     1.1.2 
#> v infer     0.5.3      v tune      0.1.1 
#> v modeldata 0.0.2      v workflows 0.1.3 
#> v parsnip   0.1.3      v yardstick 0.0.7 
#> v purrr     0.3.4
#> -- Conflicts -------------------------------------------------------------------- tidymodels_conflicts() --
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter()  masks stats::filter()
#> x dplyr::lag()     masks stats::lag()
#> x recipes::step()  masks stats::step()
library(glmnet)
#> Loading required package: Matrix
#> 
#> Attaching package: 'Matrix'
#> The following objects are masked from 'package:tidyr':
#> 
#>     expand, pack, unpack
#> Loaded glmnet 4.0-2

dat <- cbind(matrix(rnorm(200 * 16), nrow = 200),
             matrix(sample(0:1, 200 * 14, replace = TRUE), nrow = 200))

mod <- linear_reg(mode = "regression", penalty = tune(), mixture = 0) %>%
    set_engine("glmnet")

range(grid_regular(parameters(mod), levels = 10))
#> [1] 1e-10 1e+00

fit1 <- cv.glmnet(dat[, 2:30], dat[, 1], alpha = 0)
range(fit1$lambda)
#> [1]   0.01971351 197.13506375
fit1$lambda.min
#> [1] 13.27537

Created on 2020-09-08 by the reprex package (v0.3.0)
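For context, the glmnet range above is derived from the data. Below is a simplified sketch of how glmnet is understood to build its default lambda sequence for the gaussian family with `standardize = TRUE`. It is not glmnet's own code, and the function name `lambda_path_sketch` is hypothetical. The sketch works in three steps:

- It derives the largest lambda from the data.
- It divides that value by `alpha`, treating `alpha = 0` as 0.001.
- It spaces `nlambda` values on the log scale down to `lambda.min.ratio` times that largest value.

With 200 rows and 29 predictors, the default ratio is 1e-4. This agrees with the four orders of magnitude spanned by `range(fit1$lambda)`.

```r
# Hypothetical helper: a simplified sketch of glmnet's default lambda path
# (gaussian family, standardize = TRUE, no observation weights).
# Assumption: this mirrors glmnet's documented behaviour; it is not glmnet's code.
lambda_path_sketch <- function(x, y, alpha = 1, nlambda = 100,
                               lambda.min.ratio = ifelse(nrow(x) < ncol(x), 0.01, 1e-4)) {
  n <- nrow(x)
  # Centre and scale each column using the 1/n standard deviation
  x_centred <- sweep(x, 2, colMeans(x))
  x_std <- sweep(x_centred, 2, sqrt(colMeans(x_centred^2)), "/")
  y_centred <- y - mean(y)
  # Ridge (alpha = 0) has no finite lambda that zeroes every coefficient,
  # so 0.001 is used in place of alpha when choosing the largest lambda
  alpha_eff <- max(alpha, 1e-3)
  lambda_max <- max(abs(crossprod(x_std, y_centred))) / (n * alpha_eff)
  # Log-spaced from lambda_max down to lambda_max * lambda.min.ratio
  exp(seq(log(lambda_max), log(lambda_max * lambda.min.ratio), length.out = nlambda))
}

# Usage on simulated data shaped like the reprex (200 rows, 29 predictors)
set.seed(1)
x <- matrix(rnorm(200 * 29), nrow = 200)
y <- rnorm(200)
ridge_path <- lambda_path_sketch(x, y, alpha = 0)
lasso_path <- lambda_path_sketch(x, y, alpha = 1)
max(ridge_path) / max(lasso_path)  # about 1000 (= 1 / 0.001)
```

The sketch shows why a fixed penalty range cannot match glmnet in every case. The top of glmnet's path scales with the predictor-response correlations and with 1/alpha, so ridge fits in particular can need penalties far above 1.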

Should I be using this functionality differently for ridge regression?

topepo commented 4 years ago

The range of lambda is data-driven and is affected by the choice of alpha. For that reason (and a few others), we use a default range of:

> penalty()
Amount of Regularization (quantitative)
Transformer:  log-10 
Range (transformed scale): [-10, 0]

which works well in 99% of the cases.
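The following base-R sketch shows why the grid in the question tops out at 1. It spaces 10 values evenly on the log-10 range [-10, 0] shown above and back-transforms them with `10^x`. This layout is an assumption about how a regular grid is built, not dials' own code. The largest value it produces is 1, well below the `lambda.min` of about 13.3 that `cv.glmnet` selected.

```r
# Default penalty() range on the log-10 (transformed) scale
default_log10_range <- c(-10, 0)

# Evenly spaced on the transformed scale, then back-transformed
default_grid <- 10^seq(default_log10_range[1], default_log10_range[2],
                       length.out = 10)

range(default_grid)
#> [1] 1e-10 1e+00
```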

You can change the range if you need to go higher. The range is specified on the log-10 scale, hence the log10() call below:

library(tidymodels)
#> ── Attaching packages ────────────────────────────────────────── tidymodels 0.1.1 ──
#> ✓ broom     0.7.0          ✓ recipes   0.1.13    
#> ✓ dials     0.0.8.9000     ✓ rsample   0.0.7     
#> ✓ dplyr     1.0.2          ✓ tibble    3.0.3     
#> ✓ ggplot2   3.3.2          ✓ tidyr     1.1.2     
#> ✓ infer     0.5.2          ✓ tune      0.1.1.9000
#> ✓ modeldata 0.0.2          ✓ workflows 0.1.3     
#> ✓ parsnip   0.1.3.9000     ✓ yardstick 0.0.7     
#> ✓ purrr     0.3.4
#> ── Conflicts ───────────────────────────────────────────── tidymodels_conflicts() ──
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter()  masks stats::filter()
#> x dplyr::lag()     masks stats::lag()
#> x recipes::step()  masks stats::step()

mod  <- 
  linear_reg(mode = "regression", penalty = tune(), mixture = 0) %>%
  set_engine("glmnet")

mod %>% 
  parameters() %>% 
  update(penalty = penalty(log10(c(0.02, 198)))) %>% 
  grid_regular(levels = 10) %>% 
  summary()
#>     penalty        
#>  Min.   :  0.0200  
#>  1st Qu.:  0.2232  
#>  Median :  2.2556  
#>  Mean   : 30.9259  
#>  3rd Qu.: 21.5277  
#>  Max.   :198.0000

Created on 2020-09-08 by the reprex package (v0.3.0)

(this also works with the CRAN versions of these packages; the versions above are what I have currently installed)
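The `0.02` and `198` passed to `penalty()` above are the `cv.glmnet` range from the question, rounded. They are wrapped in `log10()` because `penalty()` takes its range on the transformed scale. The base-R sketch below rebuilds the same 10-level grid by spacing values evenly between the log-10 bounds and back-transforming them. The helper name `log10_regular_grid` is hypothetical and is not a dials function. Up to rounding, its summary should agree with the output above (mean about 30.93, median about 2.256).

```r
# Hypothetical helper (not part of dials): a regular grid that is evenly
# spaced on the log-10 scale, mirroring the transformation used by penalty()
log10_regular_grid <- function(lower, upper, levels) {
  10^seq(log10(lower), log10(upper), length.out = levels)
}

# Bounds taken from the cv.glmnet range in the question, rounded
penalty_grid <- log10_regular_grid(0.02, 198, levels = 10)
summary(penalty_grid)
```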

andrjohns commented 4 years ago

Excellent, thanks for the help Max!

github-actions[bot] commented 3 years ago

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.