tidymodels / multilevelmod

Parsnip wrappers for mixed-level and hierarchical models
https://multilevelmod.tidymodels.org/

Incorrect predictions from fit_resamples() when applied to lmer model #44

Open a-difabio opened 2 years ago

a-difabio commented 2 years ago

I am having trouble using the tune::fit_resamples() function on an lmer model (from the multilevelmod package). In particular, it looks like the predictions for the assessment set are calculated without properly accounting for all the possible combinations of grouping levels.

I have included a reprex showing that the results of a predict() call on an lmer object differ from the predictions obtained from a fit_resamples() call (collected with collect_predictions()).

library(tidyverse)
library(tidymodels)
library(multilevelmod)

data(mpg, package = "ggplot2")

set.seed(123)

lmer_model = linear_reg() %>% 
  set_engine("lmer")

lmer_workflow = workflow() %>% 
  add_variables(outcomes = cty,
                predictors = c(year, manufacturer, model)) %>% 
  add_model(lmer_model, formula = cty ~ year + (1|manufacturer/model))

mpg_split = mpg %>% validation_split(prop = 3/4)

analysis = mpg_split$splits[[1]] %>% analysis()
assessment = mpg_split$splits[[1]] %>% assessment()

# using predict() on the assessment dataset works as expected
predicted_via_workflow = lmer_workflow %>%
  fit(analysis) %>%
  extract_fit_engine() %>%
  predict(assessment) %>%
  plot()

# the predictions from the fit_resamples() function do not vary per group
predicted_via_tune = lmer_workflow %>% 
  fit_resamples(mpg_split, control = control_resamples(allow_par = FALSE,
                                                       save_pred = TRUE)) %>% 
  collect_predictions() %>%
  pluck(".pred") %>%
  plot()

Created on 2022-06-20 by the reprex package (v2.0.1)
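
For reference, lme4's predict() method for lmer fits takes a re.form argument that controls whether the estimated random effects enter the prediction. The sketch below uses lme4 directly, with the same formula as the reprex, to show what predictions that "do not vary per group" can look like. Whether fit_resamples() ends up calling predict() this way is an assumption and is not confirmed in this thread; the object names (fit, pred_conditional, pred_population) are illustrative only.

```r
library(lme4)

data(mpg, package = "ggplot2")

# Same formula as in the reprex, fitted directly with lme4 on the full data
fit <- lmer(cty ~ year + (1 | manufacturer / model), data = mpg)

# Conditional predictions: fixed effects plus the estimated random effects
pred_conditional <- predict(fit, newdata = mpg, re.form = NULL)

# Population-level predictions: all random effects set to zero
pred_population <- predict(fit, newdata = mpg, re.form = NA)

# `year` is the only fixed-effect predictor, so the population-level
# predictions can take at most one distinct value per year, while the
# conditional predictions also differ between manufacturers and models
n_population <- length(unique(round(pred_population, 8)))
n_conditional <- length(unique(round(pred_conditional, 8)))

n_population
n_conditional
```

If the predictions collected from fit_resamples() show the same pattern (a handful of distinct values, one per year), that would be consistent with population-level prediction rather than a failure to fit the grouping structure.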

juliasilge commented 2 years ago

Are you sure you have the latest version, which includes the fix #41 for #38? It went to CRAN on June 17.

a-difabio commented 2 years ago

I believe I am using the latest version of the package:

sessioninfo::package_info(pkgs = "multilevelmod", dependencies = FALSE)
#>  package       * version date (UTC) lib source
#>  multilevelmod   1.0.0   2022-06-17 [1] CRAN (R 4.2.0)

Created on 2022-06-21 by the reprex package (v2.0.1)

In fact, I think that before fix #41 this same code would have thrown an error without predicting anything, while now the fitted workflow can be used to predict new values.
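
One way to narrow this down might be to compare, on the same split, the predictions from the fitted workflow's own predict() method with those from the underlying lme4 object. The sketch below is self-contained and rebuilds the workflow from the reprex. It swaps validation_split() for initial_split() purely for brevity, and passes allow.new.levels = TRUE to the engine-level predict() in case a grouping level is missing from the analysis set; both choices are assumptions of this sketch, not part of the original report.

```r
library(tidymodels)
library(multilevelmod)

data(mpg, package = "ggplot2")
set.seed(123)

# Same workflow as in the reprex
lmer_workflow <- workflow() %>%
  add_variables(outcomes = cty,
                predictors = c(year, manufacturer, model)) %>%
  add_model(linear_reg() %>% set_engine("lmer"),
            formula = cty ~ year + (1 | manufacturer / model))

# A single train/test split stands in for the validation split
mpg_split <- initial_split(mpg, prop = 3/4)
analysis_set <- training(mpg_split)
assessment_set <- testing(mpg_split)

fitted_workflow <- fit(lmer_workflow, analysis_set)

# Predictions through the workflow (the parsnip prediction path)
pred_workflow <- predict(fitted_workflow, new_data = assessment_set)$.pred

# Predictions from the raw lme4 object, as in the first half of the reprex
pred_engine <- fitted_workflow %>%
  extract_fit_engine() %>%
  predict(newdata = assessment_set, allow.new.levels = TRUE)

# Differences that are not all (numerically) zero mean the two paths disagree
summary(pred_workflow - unname(pred_engine))
```

If the two sets of predictions already differ here, the discrepancy would seem to arise in how the workflow predicts rather than in the resampling step itself.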