vincentarelbundock / marginaleffects

R package to compute and plot predictions, slopes, marginal means, and comparisons (contrasts, risk ratios, odds, etc.) for over 100 classes of statistical and ML models. Conduct linear and non-linear hypothesis tests, or equivalence tests. Calculate uncertainty estimates using the delta method, bootstrapping, or simulation-based inference.
https://marginaleffects.com
Other
466 stars, 47 forks

Error in `predictions()` with `brms::categorical()` #608

Closed: mattansb closed this issue 1 year ago

mattansb commented 1 year ago

When I try to run the example from https://github.com/vincentarelbundock/marginaleffects/issues/539#issuecomment-1317013858, I get:

library(brms)
library(marginaleffects)

mod <- brm(Species ~ ., 
  data = iris,
  family = categorical(), 
  backend = "cmdstanr", cores = 4
)

predictions(mod, newdata = datagrid(Sepal.length = 4:5))
#> Error: Unable to compute predicted values with this model. You can try to supply a different dataset to the
#>   `newdata` argument. If this does not work, you can file a report on the Github Issue Tracker:
#>   https://github.com/vincentarelbundock/marginaleffects/issues
#>   
#>   This error was also raised:  could not find function "assert_dependency"
#> In addition: Warning message:
#> Some of the variable names are missing from the model data: Sepal.length 
sessionInfo()
#> R version 4.2.2 (2022-10-31 ucrt)
#> Platform: x86_64-w64-mingw32/x64 (64-bit)
#> Running under: Windows 10 x64 (build 22621)
#> 
#> Matrix products: default
#> 
#> locale:
#> [1] LC_COLLATE=English_Israel.utf8  LC_CTYPE=English_Israel.utf8    LC_MONETARY=English_Israel.utf8
#> [4] LC_NUMERIC=C                    LC_TIME=English_Israel.utf8    
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] ggdist_3.2.1.9000          marginaleffects_0.8.1.9119 brms_2.18.0                Rcpp_1.0.9                
#> 
#> loaded via a namespace (and not attached):
#>   [1] nlme_3.1-160         matrixStats_0.63.0   xts_0.12.2           insight_0.18.8       threejs_0.3.3       
#>   [6] rstan_2.26.13        tensorA_0.36.2       tools_4.2.2          backports_1.4.1      utf8_1.2.2          
#>  [11] R6_2.5.1             DT_0.27              DBI_1.1.3            colorspace_2.1-0     withr_2.5.0         
#>  [16] tidyselect_1.2.0     gridExtra_2.3        prettyunits_1.1.1    processx_3.8.0       Brobdingnag_1.2-9   
#>  [21] emmeans_1.8.4-1      curl_5.0.0           compiler_4.2.2       cli_3.6.0            shinyjs_2.1.0       
#>  [26] sandwich_3.0-2       colourpicker_1.2.0   posterior_1.3.1      scales_1.2.1         dygraphs_1.1.1.6    
#>  [31] checkmate_2.1.0      mvtnorm_1.1-3        callr_3.7.3          stringr_1.5.0        digest_0.6.31       
#>  [36] StanHeaders_2.26.13  base64enc_0.1-3      pkgconfig_2.0.3      htmltools_0.5.4      collapse_1.9.0      
#>  [41] fastmap_1.1.0        htmlwidgets_1.6.1    rlang_1.0.6          rstudioapi_0.14      shiny_1.7.4         
#>  [46] farver_2.1.1         generics_0.1.3       zoo_1.8-11           jsonlite_1.8.4       crosstalk_1.2.0     
#>  [51] gtools_3.9.4         dplyr_1.0.10         distributional_0.3.1 inline_0.3.19        magrittr_2.0.3      
#>  [56] loo_2.5.1            bayesplot_1.10.0     Matrix_1.5-3         munsell_0.5.0        fansi_1.0.3         
#>  [61] clipr_0.8.0          abind_1.4-5          lifecycle_1.0.3      stringi_1.7.12       multcomp_1.4-20     
#>  [66] MASS_7.3-58.1        pkgbuild_1.4.0       plyr_1.8.8           grid_4.2.2           parallel_4.2.2      
#>  [71] promises_1.2.0.1     crayon_1.5.2         miniUI_0.1.1.1       lattice_0.20-45      splines_4.2.2       
#>  [76] knitr_1.41           ps_1.7.2             pillar_1.8.1         igraph_1.3.5         markdown_1.4        
#>  [81] estimability_1.4.1   shinystan_2.6.0      reshape2_1.4.4       codetools_0.2-18     stats4_4.2.2        
#>  [86] rstantools_2.2.0     MSBMisc_0.0.1.14     glue_1.6.2           V8_4.2.2             data.table_1.14.6   
#>  [91] RcppParallel_5.1.6   vctrs_0.5.1          httpuv_1.6.8         gtable_0.3.1         assertthat_0.2.1    
#>  [96] ggplot2_3.4.0        xfun_0.36            mime_0.12            xtable_1.8-4         coda_0.19-4         
#> [101] later_1.3.0          survival_3.4-0       tibble_3.1.8         shinythemes_1.2.0    cmdstanr_0.5.3      
#> [106] TH.data_1.1-1        ellipsis_0.3.2       bridgesampling_1.1-2
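The warning in the output above ("Some of the variable names are missing from the model data: Sepal.length") already points at the cause. The iris column is Sepal.Length, not Sepal.length. Below is a minimal base-R sketch for catching this kind of typo before building a grid. The helper check_grid_names is hypothetical and is not part of marginaleffects. It lists requested names that are not columns of the data and suggests a case-insensitive match when one exists.

```r
# Hypothetical helper (not part of marginaleffects): report requested
# variable names that are not columns of `data`, and suggest a column
# that matches case-insensitively, if there is one.
check_grid_names <- function(requested, data) {
  missing <- setdiff(requested, names(data))
  suggestions <- vapply(missing, function(v) {
    hit <- names(data)[tolower(names(data)) == tolower(v)]
    if (length(hit) > 0) hit[1] else NA_character_
  }, character(1))
  data.frame(requested = missing, suggestion = unname(suggestions))
}

# "Sepal.length" is flagged, with "Sepal.Length" as the suggested fix;
# "Petal.Width" is a real column and is not reported.
check_grid_names(c("Sepal.length", "Petal.Width"), iris)
```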
vincentarelbundock commented 1 year ago

Thanks for the report! A few things:

  1. I’m in the middle of a major refactor and pushed an ill-advised commit to GitHub last night. Things should be fixed now. Sorry!!
  2. Please install collapse 1.9.0 and the GitHub version of insight.
  3. The variable is Sepal.Length, with a capital L; the example uses Sepal.length.
  4. Are you sure that this model is working as intended? On my computer, the posterior predictions look very weird even when I just call basic brms functions…
library(brms)
library(ggdist)
library(insight)
library(marginaleffects)

mod <- brm(Species ~ ., 
  data = iris,
  family = categorical(), 
  backend = "cmdstanr", cores = 4
)
predictions(mod, newdata = datagrid(Sepal.Length = 4:5))
# 
#       Group   Estimate 2.5 %    97.5 % Sepal.Length
#      setosa  1.000e+00     0 1.000e+00            4
#      setosa  1.000e+00     0 1.000e+00            5
#  versicolor  0.000e+00     0 1.000e+00            4
#  versicolor 3.188e-192     0 1.000e+00            5
#   virginica  0.000e+00     0 7.773e-04            4
#   virginica 3.246e-202     0 7.397e-05            5
# 
# Prediction type:  response 
# Columns: rowid, type, group, estimate, conf.low, conf.high, Species, Sepal.Width, Petal.Length, Petal.Width, Sepal.Length

hist(posterior_epred(mod))
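One plausible explanation for the extreme probabilities above (setosa at 1, other species at values like 3e-192) is perfect separation. In iris, setosa can be separated from the other two species by a single threshold on Petal.Length. Under brms's default flat priors on the regression coefficients, that would let the categorical logit coefficients drift toward very large values. This is an assumption about the cause; the thread does not establish it. The base-R check below only confirms the separation itself:

```r
# Range of Petal.Length within each species
pl_range <- tapply(iris$Petal.Length, iris$Species, range)
pl_range

# If setosa's largest petal is shorter than the smallest petal of the other
# two species, one cutoff on Petal.Length separates setosa perfectly.
setosa_max <- max(iris$Petal.Length[iris$Species == "setosa"])
others_min <- min(iris$Petal.Length[iris$Species != "setosa"])
setosa_max < others_min  # TRUE
```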

Here’s a “working” categorical model:

dat <- "https://vincentarelbundock.github.io/Rdatasets/csv/Ecdat/Heating.csv"
dat <- read.csv(dat)
mod <- brm(depvar ~ ic.gc + oc.gc,
           data = dat,
           backend = "cmdstanr",
           family = categorical(link = "logit"))
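For reference, family = categorical(link = "logit") is a multinomial logit. Each non-reference category gets its own linear predictor, the reference category (in brms, the first level of the response) is fixed at 0, and the category probabilities are a softmax over these values. Below is a minimal base-R sketch of that mapping. The function name categorical_probs and the input values are made up for illustration. The second call shows how large-magnitude linear predictors push one probability to essentially 1 and the others to tiny values, similar to the iris output earlier in the thread.

```r
# Softmax for a categorical (multinomial logit) model with a reference
# category. `eta` holds linear predictors for the NON-reference categories;
# the reference category's linear predictor is fixed at 0.
categorical_probs <- function(eta) {
  eta_full <- c(0, eta)
  # Subtract the maximum before exponentiating, for numerical stability
  z <- exp(eta_full - max(eta_full))
  z / sum(z)
}

# Moderate linear predictors: every category gets a non-trivial probability
categorical_probs(c(0.5, -1))

# Extreme linear predictors: the reference category gets probability ~1,
# the other two get values far below 1e-190
categorical_probs(c(-440, -465))
```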

I’m not exactly sure how curve_interval() is supposed to work, but at least the output seems to be proper rvar objects:

p <- avg_predictions(mod, by = "group") |> posterior_draws("rvar")

p
# 
#  Group Estimate   2.5 %  97.5 %
#     ec  0.07114 0.05606 0.08925
#     er  0.09292 0.07571 0.11266
#     gc  0.63623 0.60533 0.66782
#     gr  0.14278 0.12109 0.16736
#     hp  0.05540 0.04199 0.07190
# 
# Prediction type:  response 
# Columns: type, group, estimate, conf.low, conf.high, rvar

p$rvar |> str()
#  rvar<1000,4>[5]  0.072 ± 0.0086  0.093 ± 0.0096  0.636 ± 0.0158  0.143 ± 0.0117 ...
#  - dimnames(*)=List of 1
#   ..$ : chr [1:5] "ec" "er" "gc" "gr" ...
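The printed table and the rvar summary describe the same posterior draws in two ways. As I read posterior's print format, rvar<1000,4>[5] means 1000 iterations in each of 4 chains for 5 elements, printed as mean ± sd. The table instead reports a point estimate with an equal-tailed 95% interval (the 2.5% and 97.5% quantiles). The base-R sketch below computes both summaries from a simulated draws matrix (rows = draws, columns = categories), not from the model's actual draws. Using the median as the point estimate is an assumption about marginaleffects' default for Bayesian models.

```r
set.seed(608)
# Simulated stand-in for posterior draws, NOT the model's draws:
# 4000 draws (e.g. 4 chains x 1000 iterations) of probabilities for 5
# categories. Each row is a Dirichlet draw (normalized gammas), so rows sum to 1.
alpha <- c(ec = 0.07, er = 0.09, gc = 0.64, gr = 0.14, hp = 0.06) * 500
n_draws <- 4000
g <- matrix(rgamma(n_draws * length(alpha), shape = rep(alpha, each = n_draws)),
            nrow = n_draws)
draws <- g / rowSums(g)
colnames(draws) <- names(alpha)

# Per-category summaries: median with an equal-tailed 95% interval,
# plus the mean and sd that the rvar print shows as "mean ± sd"
summary_tab <- data.frame(
  group     = colnames(draws),
  estimate  = apply(draws, 2, median),
  conf.low  = apply(draws, 2, quantile, probs = 0.025),
  conf.high = apply(draws, 2, quantile, probs = 0.975),
  mean      = colMeans(draws),
  sd        = apply(draws, 2, sd),
  row.names = NULL
)
summary_tab
```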
sessionInfo()
# R version 4.2.2 Patched (2022-11-10 r83330)
# Platform: x86_64-pc-linux-gnu (64-bit)
# Running under: Ubuntu 22.04.1 LTS
# 
# Matrix products: default
# BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.10.0
# LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.10.0
# 
# locale:
#  [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
#  [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
#  [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
#  [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
#  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
# [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       
# 
# attached base packages:
# [1] stats     graphics  grDevices utils     datasets  methods   base     
# 
# other attached packages:
# [1] marginaleffects_0.8.1.9119 insight_0.18.8.13         
# [3] ggdist_3.2.1               brms_2.18.0               
# [5] Rcpp_1.0.10               
# 
# loaded via a namespace (and not attached):
#   [1] nlme_3.1-161         matrixStats_0.63.0   fs_1.6.0            
#   [4] xts_0.12.2           httr_1.4.4           threejs_0.3.3       
#   [7] rstan_2.21.8         tensorA_0.36.2       tools_4.2.2         
#  [10] backports_1.4.1      utf8_1.2.2           R6_2.5.1            
#  [13] DT_0.27              DBI_1.1.3            colorspace_2.1-0    
#  [16] withr_2.5.0          tidyselect_1.2.0     gridExtra_2.3       
#  [19] prettyunits_1.1.1    processx_3.8.0       Brobdingnag_1.2-8   
#  [22] curl_5.0.0           compiler_4.2.2       cli_3.6.0           
#  [25] xml2_1.3.3           shinyjs_2.1.0        colourpicker_1.2.0  
#  [28] posterior_1.3.1      scales_1.2.1         dygraphs_1.1.1.6    
#  [31] checkmate_2.1.0      mvtnorm_1.1-3        callr_3.7.3         
#  [34] stringr_1.5.0        digest_0.6.31        StanHeaders_2.21.0-7
#  [37] rmarkdown_2.20       base64enc_0.1-3      pkgconfig_2.0.3     
#  [40] htmltools_0.5.4      highr_0.10           collapse_1.9.2      
#  [43] fastmap_1.1.0        htmlwidgets_1.6.1    rlang_1.0.6         
#  [46] shiny_1.7.4          farver_2.1.1         generics_0.1.3      
#  [49] jsonlite_1.8.4       zoo_1.8-11           crosstalk_1.2.0     
#  [52] gtools_3.9.4         dplyr_1.0.10         distributional_0.3.1
#  [55] inline_0.3.19        magrittr_2.0.3       loo_2.5.1           
#  [58] bayesplot_1.10.0     Matrix_1.5-3         munsell_0.5.0       
#  [61] fansi_1.0.4          abind_1.4-5          lifecycle_1.0.3     
#  [64] stringi_1.7.12       yaml_2.3.7           pkgbuild_1.4.0      
#  [67] plyr_1.8.8           grid_4.2.2           parallel_4.2.2      
#  [70] promises_1.2.0.1     crayon_1.5.2         miniUI_0.1.1.1      
#  [73] lattice_0.20-45      knitr_1.42           ps_1.7.2            
#  [76] pillar_1.8.1         igraph_1.3.5         markdown_1.4        
#  [79] shinystan_2.6.0      reshape2_1.4.4       codetools_0.2-18    
#  [82] stats4_4.2.2         rstantools_2.2.0     reprex_2.0.2        
#  [85] glue_1.6.2           evaluate_0.20        data.table_1.14.6   
#  [88] RcppParallel_5.1.6   vctrs_0.5.2          httpuv_1.6.8        
#  [91] gtable_0.3.1         assertthat_0.2.1     ggplot2_3.4.0       
#  [94] xfun_0.36            mime_0.12            xtable_1.8-4        
#  [97] coda_0.19-4          later_1.3.0          tibble_3.1.8        
# [100] shinythemes_1.2.0    cmdstanr_0.5.3       ellipsis_0.3.2      
# [103] bridgesampling_1.1-2
mattansb commented 1 year ago

Thanks!

The model is garbage; this is just a toy example (:

library(brms)
library(marginaleffects)
library(ggdist)
library(ggplot2)

mod <- brm(Species ~ ., 
  data = iris,
  family = categorical(), 
  backend = "cmdstanr", cores = 4, iter = 300
)

# Posterior predictions over a grid of Sepal.Width values, as rvar draws
predictions(mod, newdata = datagrid(Sepal.Width = seq(2, 4.5, length.out = 50))) |> 
  posterior_draws(shape = "rvar") |> 
  # ggdist::curve_interval() applied to the rvar column
  curve_interval(rvar, .along = c("Sepal.Width", "group")) |> 
  ggplot(aes(Sepal.Width, estimate, color = group)) + 
  facet_grid(~group) + 
  geom_ribbon(aes(ymin = conf.low, ymax = conf.high, fill = group), 
              color = NA, alpha = 0.4) + 
  geom_line()
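In the call above, datagrid(Sepal.Width = ...) varies Sepal.Width across 50 values. As far as I understand marginaleffects' defaults, it holds every other column at its mean (numeric) or its most frequent value (factor). The base-R imitation below is for illustration only. make_grid is a hypothetical helper, not the package's implementation.

```r
# Hypothetical helper imitating a "typical values" grid: the variables named
# in `...` take the supplied values; every other column of `data` is held at
# its mean (numeric) or its most frequent level (factor/character).
make_grid <- function(data, ...) {
  varying <- list(...)
  others <- setdiff(names(data), names(varying))
  typical <- lapply(data[others], function(x) {
    if (is.numeric(x)) {
      mean(x)
    } else {
      tab <- table(x)
      val <- names(tab)[which.max(tab)]
      # keep factors as factors, with their original levels
      if (is.factor(x)) factor(val, levels = levels(x)) else val
    }
  })
  expand.grid(c(varying, typical), stringsAsFactors = FALSE)
}

grid_sw <- make_grid(iris, Sepal.Width = seq(2, 4.5, length.out = 50))
head(grid_sw)
```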
