multi_predict fails for glmnet models when generating predictions for only a single row #48

Closed 1 year ago

joranE commented:

multi_predict() fails on glmnet models when passing only a single row of data.

# This works:
> m <- poisson_reg() %>% 
+   set_engine("glmnet") %>% 
+   fit(count ~ (.)^2, data = seniors[,2:4])
> multi_predict(m,new_data = seniors[1:2,])
# A tibble: 2 × 1
  .pred            
  <list>           
1 <tibble [72 × 2]>
2 <tibble [72 × 2]>


# This does not:
> m <- poisson_reg() %>% 
+   set_engine("glmnet") %>% 
+   fit(count ~ (.)^2, data = seniors[,2:4])
> multi_predict(m,new_data = seniors[1,])
Error in predict.glmnet(object = object$fit, newx = as.matrix(new_data[,  : 
  The number of variables in newx must be 3

hfrick commented:

Thanks for the report, @joranE! That's a bug: the dimensions get dropped and incorrectly reconstructed here:


hfrick commented:

This also affects predictions for a single penalty value:

library(poissonreg)
#> Loading required package: parsnip

m <- poisson_reg() %>% 
  set_engine("glmnet") %>% 
  fit(count ~ (.)^2, data = seniors[,2:4])

multi_predict(m, new_data = seniors[1:2,])
#> # A tibble: 2 × 1
#>   .pred            
#>   <list>           
#> 1 <tibble [72 × 2]>
#> 2 <tibble [72 × 2]>
multi_predict(m, new_data = seniors[1,])
#> Error in predict.glmnet(object = object$fit, newx = as.matrix(new_data[, : The number of variables in newx must be 3

predict(m, new_data = seniors[1:2,], penalty = 0.1)
#> # A tibble: 2 × 1
#>   .pred
#>   <dbl>
#> 1  724.
#> 2  724.
predict(m, new_data = seniors[1,], penalty = 0.1)
#> Error in predict.glmnet(object = object$fit, newx = as.matrix(new_data[, : The number of variables in newx must be 3

Created on 2022-12-19 with reprex v2.0.2
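
For readers unfamiliar with the failure mode: the dimension drop mentioned earlier in the thread can be sketched in base R alone. This is only an illustration of the general mechanism, not the package's source. The object names (`one_row`, `cols`, `dropped`, `kept`) and the column names are made up for the example. It assumes the predictors arrive as a one-row matrix that is then subset by column, which is consistent with the `as.matrix(new_data[, ...` call visible in the error message.

```r
# Hypothetical one-row predictor matrix; column names are placeholders.
one_row <- matrix(c(1, 3, 5), nrow = 1,
                  dimnames = list(NULL, c("a", "b", "c")))
cols <- c("a", "b", "c")

# Default subsetting (drop = TRUE) turns a one-row matrix into a plain vector.
dropped <- one_row[, cols]
is.null(dim(dropped))
#> [1] TRUE

# as.matrix() on a vector builds a one-column matrix, so three variables
# in one row become three rows of a single variable.
dim(as.matrix(dropped))
#> [1] 3 1

# drop = FALSE keeps the 1 x 3 shape that a predict method would expect.
kept <- one_row[, cols, drop = FALSE]
dim(kept)
#> [1] 1 3
```

With two or more rows the default subsetting already returns a matrix, which would explain why only the single-row calls above fail.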
