mjskay / tidybayes

Bayesian analysis + tidy data + geoms (R package)
http://mjskay.github.io/tidybayes
GNU General Public License v3.0

posterior_predict() returning incorrect fitted values for rstanarm multivariate regression models #271

Closed rudeboybert closed 3 years ago

rudeboybert commented 3 years ago

Hello, thank you for a wonderful package. If only it existed while I was in grad school!

add_predicted_draws() seems to be returning incorrect fitted values for rstanarm::stan_mvmer() multivariate models. I've narrowed the issue down to the m argument: rstanarm::posterior_predict() takes it for multivariate stanmvreg models, but not for regular stanreg models, and add_predicted_draws() does not pass it along.

I'm happy to take a stab at a PR, but first wanted to check in and see if you wanted to go down the stanmvreg rabbit hole, or if you'd prefer throwing a "we don't support these types of models" warning like you do for "ulam", "quap", "map", and "map2stan" models.

Here is a reprex:

Create 3-dim outcome multivariate example data set

library(tidyverse)
library(rstanarm)
library(tidybayes)

multivariate_data <- bind_rows(
  tibble(x = rep(1:50, times = 4)) %>% mutate(obs = "y1", y = 0 + 0.4*x),
  tibble(x = rep(1:50, times = 4)) %>% mutate(obs = "y2", y = 30 + 0.4*x),
  tibble(x = rep(1:50, times = 4)) %>% mutate(obs = "y3", y = 60 + 0.4*x)
) %>%
  mutate(
    group = rep(1:4, each = 50) %>% rep(times = 3),
    y = y + rnorm(n())
  )

ggplot(multivariate_data, aes(x = x, y = y, col = obs)) +
  geom_point() +
  geom_smooth(method = "lm", se = FALSE)

Note the group means:

multivariate_data %>%
  group_by(obs) %>%
  summarize(mean_y = mean(y))
#> `summarise()` ungrouping output (override with `.groups` argument)
#> # A tibble: 3 x 2
#>   obs   mean_y
#>   <chr>  <dbl>
#> 1 y1      10.2
#> 2 y2      40.1
#> 3 y3      70.3

Fit multivariate regression model & get (incorrect) posterior fitted values

# Convert data to wide format
multivariate_data_wide <- multivariate_data %>%
  pivot_wider(names_from = obs, values_from = y)

# Fit model
stanmvreg_model <- stan_mvmer(
  formula = list(
    y1 ~ x + (1|group),
    y2 ~ x + (1|group),
    y3 ~ x + (1|group)
  ),
  data = multivariate_data_wide,
  seed = 76,
  chains = 1,
  iter = 2000
)

# Get posterior means
multivariate_data %>%
  add_predicted_draws(stanmvreg_model) %>%
  group_by(obs) %>%
  summarize(mean_y = mean(y), mean_y_hat = mean(.prediction))
#> `summarise()` ungrouping output (override with `.groups` argument)
#> # A tibble: 3 x 3
#>   obs   mean_y mean_y_hat
#>   <chr>  <dbl>      <dbl>
#> 1 y1      10.2       10.2
#> 2 y2      40.1       10.2
#> 3 y3      70.3       10.2

As you can see, the posterior means are off for y2 and y3.

Posterior fitted values using rstanarm package

Calling rstanarm::posterior_predict() directly, which is the function wrapped by add_predicted_draws(), gives the same result: only the posterior mean for y1 is returned.

stanmvreg_model %>%
  posterior_predict() %>%
  apply(1, mean) %>%
  mean()
#> [1] 10.20104

However, the Usage section of the help file ?rstanarm::posterior_predict shows an extra argument m for models of class stanmvreg, which defaults to m = 1. If you specify m = 1, 2, 3 for the (y1, y2, y3) multivariate outcome we have, you get the correct posterior means:

stanmvreg_model %>%
  posterior_predict(m = 1) %>%
  apply(1, mean) %>%
  mean()
#> [1] 10.20179
stanmvreg_model %>%
  posterior_predict(m = 2) %>%
  apply(1, mean) %>%
  mean()
#> [1] 40.1348
stanmvreg_model %>%
  posterior_predict(m = 3) %>%
  apply(1, mean) %>%
  mean()
#> [1] 70.29842

Attempting to pass an m argument to add_predicted_draws() throws an error, but the error refers to a model of type "numeric":

multivariate_data %>%
  add_predicted_draws(stanmvreg_model, m = 2)
#> Error in predicted_draws.default(model, newdata, prediction, ..., n = n, : Models of type "numeric" are not currently supported by `predicted_draws`.
#> You might try using `add_draws()` for models that do not have explicit fit/prediction
#> support; see help("add_draws") for an example. See also help("tidybayes-models") for
#> more information on what functions are supported by what model types.

sessionInfo()

R version 4.0.1 (2020-06-06)
Platform: x86_64-apple-darwin17.0 (64-bit)
Running under: macOS Catalina 10.15.6

Matrix products: default
BLAS:   /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] tidybayes_2.1.1.9000 rstanarm_2.21.1      Rcpp_1.0.5           forcats_0.5.0       
 [5] stringr_1.4.0        dplyr_1.0.2          purrr_0.3.4          readr_1.3.1         
 [9] tidyr_1.1.2          tibble_3.0.3         ggplot2_3.3.2        tidyverse_1.3.0     

loaded via a namespace (and not attached):
  [1] minqa_1.2.4          colorspace_1.4-1     ellipsis_0.3.1       ggridges_0.5.2      
  [5] rsconnect_0.8.16     markdown_1.1         base64enc_0.1-3      fs_1.5.0            
  [9] rstudioapi_0.11      farver_2.0.3         rstan_2.21.2         svUnit_1.0.3        
 [13] DT_0.15              fansi_0.4.1          lubridate_1.7.9      xml2_1.3.2          
 [17] splines_4.0.1        codetools_0.2-16     knitr_1.30           shinythemes_1.1.2   
 [21] bayesplot_1.7.2      jsonlite_1.7.1       nloptr_1.2.2.2       packrat_0.5.0       
 [25] broom_0.7.0          dbplyr_1.4.4         ggdist_2.2.0         shiny_1.5.0         
 [29] clipr_0.7.0          compiler_4.0.1       httr_1.4.2           backports_1.1.10    
 [33] assertthat_0.2.1     Matrix_1.2-18        fastmap_1.0.1        cli_2.0.2           
 [37] later_1.1.0.1        htmltools_0.5.0      prettyunits_1.1.1    tools_4.0.1         
 [41] igraph_1.2.5         coda_0.19-3          gtable_0.3.0         glue_1.4.2          
 [45] reshape2_1.4.4       V8_3.2.0             cellranger_1.1.0     vctrs_0.3.4         
 [49] nlme_3.1-149         crosstalk_1.1.0.1    xfun_0.17            ps_1.3.4            
 [53] lme4_1.1-23          rvest_0.3.6          mime_0.9             miniUI_0.1.1.1      
 [57] lifecycle_0.2.0      gtools_3.8.2         statmod_1.4.34       MASS_7.3-53         
 [61] zoo_1.8-8            scales_1.1.1         colourpicker_1.1.0   hms_0.5.3           
 [65] promises_1.1.1       parallel_4.0.1       inline_0.3.16        shinystan_2.5.0     
 [69] yaml_2.2.1           curl_4.3             gridExtra_2.3        loo_2.3.1           
 [73] StanHeaders_2.21.0-6 stringi_1.5.3        dygraphs_1.1.1.6     boot_1.3-25         
 [77] pkgbuild_1.1.0       rlang_0.4.7          pkgconfig_2.0.3      matrixStats_0.57.0  
 [81] distributional_0.2.0 evaluate_0.14        lattice_0.20-41      labeling_0.3        
 [85] rstantools_2.1.1     htmlwidgets_1.5.1    tidyselect_1.1.0     processx_3.4.4      
 [89] plyr_1.8.6           magrittr_1.5         R6_2.4.1             generics_0.0.2      
 [93] DBI_1.1.0            whisker_0.4          mgcv_1.8-33          pillar_1.4.6        
 [97] haven_2.3.1          withr_2.3.0          xts_0.12.1           survival_3.2-7      
[101] modelr_0.1.8         crayon_1.3.4         arrayhelpers_1.1-0   utf8_1.1.4          
[105] rmarkdown_2.3        grid_4.0.1           readxl_1.3.1         blob_1.2.1          
[109] callr_3.4.4          threejs_0.3.3        reprex_0.3.0         digest_0.6.25       
[113] xtable_1.8-4         httpuv_1.5.4         RcppParallel_5.0.2   stats4_4.0.1        
[117] munsell_0.5.0        shinyjs_2.0.0

mjskay commented 3 years ago

Thanks for this! Yes, I'd definitely like to support this model type, and a PR would be welcome.

The ideal solution I think would be to pattern the solution after how tidybayes handles multivariate models for brms, using the (slightly poorly named, for historical reasons) category argument. In the best case scenario, if it is possible to determine the names of the different y variables from the model object itself, the default approach should be to generate predictions from all of the response variables and return a tidy format dataframe with all response variables and a .category column indicating which y variable each row comes from. Does that make sense?
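To make the proposed output shape concrete, here is a rough base-R sketch. It is not how tidybayes or rstanarm implement anything: pred_list, to_long, and the simulated numbers are made up for illustration, standing in for one draws-by-rows prediction matrix per response.

```r
set.seed(1)

# Simulated stand-ins for per-response prediction matrices (draws x rows),
# i.e. the shape a posterior_predict()-style function returns. Hypothetical data.
n_draws <- 4
n_rows <- 2
pred_list <- lapply(c(y1 = 0, y2 = 30, y3 = 60), function(offset) {
  matrix(rnorm(n_draws * n_rows, mean = offset), nrow = n_draws, ncol = n_rows)
})

# Stack one matrix into long format and tag every row with its response name.
# as.vector() unrolls the matrix column by column, so .row varies slowest.
to_long <- function(mat, category) {
  data.frame(
    .row = rep(seq_len(ncol(mat)), each = nrow(mat)),
    .draw = rep(seq_len(nrow(mat)), times = ncol(mat)),
    .category = category,
    .prediction = as.vector(mat)
  )
}

# One tidy data frame holding all responses, distinguished by .category.
tidy_preds <- do.call(rbind, unname(Map(to_long, pred_list, names(pred_list))))
head(tidy_preds)
```

Each response contributes n_draws * n_rows rows, so a downstream group_by(.category) would separate y1, y2, and y3 without any extra bookkeeping.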

rudeboybert commented 3 years ago

Yes, your comment makes sense and thanks for the pointers. I'll take a stab at it in the next few days using brms as a template.

mjskay commented 3 years ago

Sweet, thanks! It would be much appreciated. Let me know if you need any pointers about the internals, the fitted/predicted_draws stuff for brms is particularly hairy.

mjskay commented 3 years ago

Note to self: pretty sure the m parameter doesn't work because of partial argument matching. Should be able to fix this in the next round of iterations for predicted_draws.
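For anyone following along, this is one plausible reading of the "numeric" error above, sketched in plain R. add_draws_sketch is a hypothetical stand-in that only borrows the leading argument names (newdata, model) visible in the error message; it is not the tidybayes function.

```r
# Hypothetical stand-in: same leading argument names as in the error message,
# but it only reports the class of whatever ends up bound to `model`.
add_draws_sketch <- function(newdata, model, ...) {
  class(model)[1]
}

# Dummy object standing in for a fitted model; no actual fit is needed here.
fit <- structure(list(), class = "stanmvreg")
grid <- data.frame(x = 1:3)

# Without `m`, `fit` is matched positionally to `model`.
add_draws_sketch(grid, fit)         # "stanmvreg"

# With `m = 2`, R partially matches the name `m` to the formal `model`,
# so `model` becomes the number 2 and `fit` falls through into `...`.
add_draws_sketch(grid, fit, m = 2)  # "numeric"
```

R only applies partial matching to formals that appear before `...`, which is why where an argument sits in the signature matters for pass-through arguments like m.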

mjskay commented 3 years ago

This should now be fixed in the github version: you can now pass the m parameter through properly. So you can do something like this (note obs should not be in the prediction grid):

grid = multivariate_data %>%
  modelr::data_grid(x, group) 

preds = bind_rows(
  add_predicted_draws(mutate(grid, obs = "y1"), stanmvreg_model, m = 1),
  add_predicted_draws(mutate(grid, obs = "y2"), stanmvreg_model, m = 2),
  add_predicted_draws(mutate(grid, obs = "y3"), stanmvreg_model, m = 3)
)

preds %>%
  median_qi()
# # A tibble: 600 x 10
#        x group obs    .row .prediction .lower .upper .width .point .interval
#    <int> <int> <chr> <int>       <dbl>  <dbl>  <dbl>  <dbl> <chr>  <chr>    
#  1     1     1 y1        1       0.616  -1.48   2.68   0.95 median qi       
#  2     1     1 y2        1      30.4    28.4   32.3    0.95 median qi       
#  3     1     1 y3        1      60.3    58.3   62.5    0.95 median qi       
#  4     1     2 y1        2       0.390  -1.61   2.60   0.95 median qi       
#  5     1     2 y2        2      30.4    28.4   32.3    0.95 median qi       
#  6     1     2 y3        2      60.4    58.4   62.2    0.95 median qi       
#  7     1     3 y1        3       0.452  -1.82   2.48   0.95 median qi       
#  8     1     3 y2        3      30.4    28.4   32.4    0.95 median qi       
#  9     1     3 y3        3      60.6    58.5   62.4    0.95 median qi       
# 10     1     4 y1        4       0.566  -1.53   2.64   0.95 median qi       
# # ... with 590 more rows
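The row count in the output above can be checked with a small base-R sketch. It assumes the x/group layout from the reprex (x from 1 to 50 in each of 4 groups) and uses expand.grid() as a stand-in for modelr::data_grid(x, group), which should give the same set of combinations, though possibly in a different order.

```r
# Rebuild the x/group/obs structure of the reprex data; the outcome values
# are not needed to count rows.
design <- data.frame(
  x = rep(rep(1:50, times = 4), times = 3),
  group = rep(rep(1:4, each = 50), times = 3),
  obs = rep(c("y1", "y2", "y3"), each = 200)
)

# All unique x-by-group combinations, with obs deliberately left out.
grid <- expand.grid(
  x = sort(unique(design$x)),
  group = sort(unique(design$group))
)

nrow(grid)      # 200
nrow(grid) * 3  # 600: one copy of the grid per outcome
```

Leaving obs out of the grid keeps each x/group combination unique, so each call with a given m predicts one outcome once per combination, and obs is then added back only as a label.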

The next version will add an rvar-based workflow, using the rvar datatype from {posterior} once it hits CRAN (which should be soon). With it, you will also be able to easily create rvar columns of predictions using add_predicted_rvars() instead of add_predicted_draws():

multivariate_data %>%
  modelr::data_grid(x, group) %>%
  add_predicted_rvars(stanmvreg_model, prediction = "y1", m = 1) %>%
  add_predicted_rvars(stanmvreg_model, prediction = "y2", m = 2) %>%
  add_predicted_rvars(stanmvreg_model, prediction = "y3", m = 3)
# # A tibble: 200 x 5
#        x group          y1         y2        y3
#    <int> <int>      <rvar>     <rvar>    <rvar>
#  1     1     1  0.57 ± 1.0  30 ± 1.00  60 ± 1.0
#  2     1     2  0.48 ± 1.1  30 ± 1.04  60 ± 1.0
#  3     1     3  0.47 ± 1.1  30 ± 1.03  61 ± 1.0
#  4     1     4  0.52 ± 1.1  30 ± 0.99  60 ± 1.0
#  5     2     1  0.98 ± 1.0  31 ± 1.03  61 ± 1.0
#  6     2     2  0.88 ± 1.1  31 ± 1.03  61 ± 1.0
#  7     2     3  0.87 ± 1.1  31 ± 1.02  61 ± 1.0
#  8     2     4  0.91 ± 1.1  31 ± 1.03  61 ± 1.0
#  9     3     1  1.33 ± 1.1  31 ± 1.02  61 ± 1.1
# 10     3     2  1.29 ± 1.1  31 ± 1.01  61 ± 1.0
# # ... with 190 more rows

For more on the rvar stuff you can check out vignette("rvar", package = "posterior") or vignette("tidy-posterior", package = "tidybayes") in the GitHub versions of both packages.

rudeboybert commented 3 years ago

Awesome! Thanks for doing this!