tidymodels / poissonreg

parsnip wrappers for Poisson regression
https://poissonreg.tidymodels.org

glmnet: Default to mean count for `multi_predict()` #63

Open hfrick opened 1 year ago

hfrick commented 1 year ago

Both `predict()` and `multi_predict()` default to `type = NULL`. For `poisson_reg(engine = "glmnet")`, however, `predict()` returns the mean count (glmnet's `type = "response"`), while `multi_predict()` returns the linear predictor (glmnet's `type = "link"`).
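The two outputs are linked. Under the log link of a Poisson model, the mean count is `exp()` of the linear predictor. In the reprex below, for example, `exp(6.29)` is about 540, which matches the first row of the `predict()` output. The following minimal sketch shows that relationship using base R's `glm()` on simulated data rather than glmnet. The data and object names are illustrative only.

```r
# Illustrative sketch with base R's glm() (not glmnet): for a Poisson model
# with the default log link, the "response" prediction (mean count) equals
# exp() of the "link" prediction (linear predictor).
set.seed(1)
x <- rnorm(50)
y <- rpois(50, lambda = exp(1 + 0.5 * x))  # simulated counts
fit <- glm(y ~ x, family = poisson())

eta <- predict(fit, type = "link")      # linear predictor
mu  <- predict(fit, type = "response")  # mean count

# The mean count is the inverse link (exp) applied to the linear predictor
stopifnot(isTRUE(all.equal(mu, exp(eta))))
```

Because the two scales differ only by this transformation, a user who receives the linear predictor can recover the mean count by applying `exp()`. The choice of default still determines which of the two scales users see when they do not specify `type`.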

I would expect both functions to use the same default, and I would expect that default to be the mean count.

Side note: we introduced a prediction type `"linear_pred"` for censored regression, which we might extend to other modes.

library(poissonreg)
#> Loading required package: parsnip

data(seniors, package = "poissonreg")

spec <- poisson_reg(penalty = 0.1, mixture = 0.3) %>%
  set_engine("glmnet", nlambda = 15)
f_fit <- fit(spec, count ~ ., data = seniors)

# default: type = NULL -> "numeric"
our_pred <- predict(f_fit, seniors, penalty = 0.1)
our_pred
#> # A tibble: 8 × 1
#>   .pred
#>   <dbl>
#> 1 540. 
#> 2 740. 
#> 3 282. 
#> 4 387. 
#> 5  90.7
#> 6 124. 
#> 7  47.4
#> 8  65.0
identical(our_pred, predict(f_fit, seniors, penalty = 0.1, type = "numeric"))
#> [1] TRUE

# default: type = NULL -> linear predictor (glmnet's type = "link")
our_multi_pred <- multi_predict(f_fit, seniors, penalty = 0.1) %>%
  tidyr::unnest(cols = .pred)
our_multi_pred
#> # A tibble: 8 × 2
#>   penalty .pred
#>     <dbl> <dbl>
#> 1     0.1  6.29
#> 2     0.1  6.61
#> 3     0.1  5.64
#> 4     0.1  5.96
#> 5     0.1  4.51
#> 6     0.1  4.82
#> 7     0.1  3.86
#> 8     0.1  4.17

seniors_x <- model.matrix(~ ., data = seniors[, -4])[, -1]
seniors_y <- seniors$count
glmnet_fit <- glmnet::glmnet(x = seniors_x, y = seniors_y, family = "poisson",
                          alpha = 0.3, nlambda = 15)
waldo::compare(
  our_pred$.pred,
  predict(glmnet_fit, seniors_x, s = 0.1, type = "response") %>% as.vector()
)
#> ✔ No differences

waldo::compare(
  our_multi_pred$.pred,
  predict(glmnet_fit, seniors_x, s = 0.1, type = "link") %>% as.vector()
)
#> ✔ No differences

Created on 2023-01-21 with reprex v2.0.2