mattansb closed this issue 1 year ago.
Thanks for the report! A few things:
- collapse 1.9.0 and the GitHub version of insight
- In Sepal.Length, the L should be capitalized.
- brms functions…
library(brms)
library(ggdist)
library(insight)
library(marginaleffects)
mod <- brm(Species ~ .,
data = iris,
family = categorical(),
backend = "cmdstanr", cores = 4
)
predictions(mod, newdata = datagrid(Sepal.Length = 4:5))
#
# Group Estimate 2.5 % 97.5 % Sepal.Length
# setosa 1.000e+00 0 1.000e+00 4
# setosa 1.000e+00 0 1.000e+00 5
# versicolor 0.000e+00 0 1.000e+00 4
# versicolor 3.188e-192 0 1.000e+00 5
# virginica 0.000e+00 0 7.773e-04 4
# virginica 3.246e-202 0 7.397e-05 5
#
# Prediction type: response
# Columns: rowid, type, group, estimate, conf.low, conf.high, Species, Sepal.Width, Petal.Length, Petal.Width, Sepal.Length
hist(posterior_epred(mod))
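For reference, the rows above vary only Sepal.Length, and every other predictor is held at a single "typical" value. Below is a minimal base-R sketch of that kind of grid, assuming datagrid()'s documented defaults (numeric columns at their mean, other columns at their most frequent value). make_grid() is a hypothetical helper for illustration, not part of marginaleffects.

```r
# Hypothetical helper mimicking datagrid()'s defaults: numeric columns at
# their mean, other columns at their mode; named arguments override columns.
make_grid <- function(data, ...) {
  fixed <- list(...)
  typical <- lapply(data, function(col) {
    if (is.numeric(col)) {
      mean(col)
    } else {
      tab <- table(col)
      val <- names(tab)[which.max(tab)]  # first most frequent level
      if (is.factor(col)) factor(val, levels = levels(col)) else val
    }
  })
  # Replace the user-specified columns with the requested values
  typical[names(fixed)] <- fixed
  # One row per combination of the supplied values
  expand.grid(typical, stringsAsFactors = FALSE)
}

grid <- make_grid(iris, Sepal.Length = 4:5)
grid
```

Only the varied column changes across rows, which is why each Group appears once per Sepal.Length value in the output above.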
Here’s a “working” categorical model:
dat <- "https://vincentarelbundock.github.io/Rdatasets/csv/Ecdat/Heating.csv"
dat <- read.csv(dat)
mod <- brm(depvar ~ ic.gc + oc.gc,
data = dat,
backend = "cmdstanr",
family = categorical(link = "logit"))
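Both models use brms's categorical() family. With the logit link, each non-reference category gets its own linear predictor, and the category probabilities come from a softmax in which the reference category's predictor is fixed at 0 (assumed here to be the first level, which I believe is the brms default). The sketch below is illustrative only, and softmax_ref() is a hypothetical helper. It also shows how large gaps between linear predictors, as in the almost perfectly separated iris toy model, can produce probabilities as tiny as those printed earlier.

```r
# Hypothetical helper: convert K - 1 linear predictors (non-reference
# categories) into K category probabilities via the softmax.
softmax_ref <- function(eta) {
  eta <- c(0, eta)          # reference category's predictor is fixed at 0
  eta <- eta - max(eta)     # shift by the max to avoid overflow in exp()
  exp(eta) / sum(exp(eta))
}

# Moderate linear predictors: probability spread across categories
p <- softmax_ref(c(1.2, -0.5))
round(p, 3)

# Large gaps between linear predictors: near-degenerate probabilities
q <- softmax_ref(c(-440, -460))
q
```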
I’m not exactly sure how curve_interval() is supposed to work, but at least the output seems to be proper rvar objects:
p <- avg_predictions(mod, by = "group") |> posterior_draws("rvar")
p
#
# Group Estimate 2.5 % 97.5 %
# ec 0.07114 0.05606 0.08925
# er 0.09292 0.07571 0.11266
# gc 0.63623 0.60533 0.66782
# gr 0.14278 0.12109 0.16736
# hp 0.05540 0.04199 0.07190
#
# Prediction type: response
# Columns: type, group, estimate, conf.low, conf.high, rvar
p$rvar |> str()
# rvar<1000,4>[5] 0.072 ± 0.0086 0.093 ± 0.0096 0.636 ± 0.0158 0.143 ± 0.0117 ...
# - dimnames(*)=List of 1
# ..$ : chr [1:5] "ec" "er" "gc" "gr" ...
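As a rough illustration of what avg_predictions(mod, by = "group") and the rvar summary above report, the sketch below simulates hypothetical draws of per-observation category probabilities (no brms model is involved), averages them over observations within each draw, and then summarises each category. Assumptions: the point estimate is the posterior median and the interval uses the 2.5% and 97.5% quantiles, which I believe are marginaleffects' defaults for Bayesian models; the "mean ± sd" strings only mimic how rvar objects print.

```r
set.seed(1)
n_draws <- 1000
n_obs <- 50
groups <- c("ec", "er", "gc", "gr", "hp")
n_cat <- length(groups)

# Hypothetical posterior draws: draws x observations x categories
eta <- array(rnorm(n_draws * n_obs * n_cat), dim = c(n_draws, n_obs, n_cat))
probs <- exp(eta)
# Normalise so that, for each draw and observation, categories sum to 1
probs <- sweep(probs, c(1, 2), apply(probs, c(1, 2), sum), "/")

# Average over observations within each draw: draws x categories
avg <- apply(probs, c(1, 3), mean)
colnames(avg) <- groups

# Summarise each category's posterior distribution
summary_tab <- data.frame(
  group = groups,
  estimate = apply(avg, 2, median),
  conf.low = apply(avg, 2, quantile, probs = 0.025),
  conf.high = apply(avg, 2, quantile, probs = 0.975),
  row.names = NULL
)
summary_tab

# "mean ± sd" strings in the style of an rvar printout
sprintf("%.3f \u00b1 %.4f", colMeans(avg), apply(avg, 2, sd))
```

Within each draw the category averages sum to 1, so their posterior means do as well, but medians need not: the Estimate column above sums to about 0.998.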
sessionInfo()
# R version 4.2.2 Patched (2022-11-10 r83330)
# Platform: x86_64-pc-linux-gnu (64-bit)
# Running under: Ubuntu 22.04.1 LTS
#
# Matrix products: default
# BLAS: /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.10.0
# LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.10.0
#
# locale:
# [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
# [3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8
# [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
# [7] LC_PAPER=en_US.UTF-8 LC_NAME=C
# [9] LC_ADDRESS=C LC_TELEPHONE=C
# [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
#
# attached base packages:
# [1] stats graphics grDevices utils datasets methods base
#
# other attached packages:
# [1] marginaleffects_0.8.1.9119 insight_0.18.8.13
# [3] ggdist_3.2.1 brms_2.18.0
# [5] Rcpp_1.0.10
#
# loaded via a namespace (and not attached):
# [1] nlme_3.1-161 matrixStats_0.63.0 fs_1.6.0
# [4] xts_0.12.2 httr_1.4.4 threejs_0.3.3
# [7] rstan_2.21.8 tensorA_0.36.2 tools_4.2.2
# [10] backports_1.4.1 utf8_1.2.2 R6_2.5.1
# [13] DT_0.27 DBI_1.1.3 colorspace_2.1-0
# [16] withr_2.5.0 tidyselect_1.2.0 gridExtra_2.3
# [19] prettyunits_1.1.1 processx_3.8.0 Brobdingnag_1.2-8
# [22] curl_5.0.0 compiler_4.2.2 cli_3.6.0
# [25] xml2_1.3.3 shinyjs_2.1.0 colourpicker_1.2.0
# [28] posterior_1.3.1 scales_1.2.1 dygraphs_1.1.1.6
# [31] checkmate_2.1.0 mvtnorm_1.1-3 callr_3.7.3
# [34] stringr_1.5.0 digest_0.6.31 StanHeaders_2.21.0-7
# [37] rmarkdown_2.20 base64enc_0.1-3 pkgconfig_2.0.3
# [40] htmltools_0.5.4 highr_0.10 collapse_1.9.2
# [43] fastmap_1.1.0 htmlwidgets_1.6.1 rlang_1.0.6
# [46] shiny_1.7.4 farver_2.1.1 generics_0.1.3
# [49] jsonlite_1.8.4 zoo_1.8-11 crosstalk_1.2.0
# [52] gtools_3.9.4 dplyr_1.0.10 distributional_0.3.1
# [55] inline_0.3.19 magrittr_2.0.3 loo_2.5.1
# [58] bayesplot_1.10.0 Matrix_1.5-3 munsell_0.5.0
# [61] fansi_1.0.4 abind_1.4-5 lifecycle_1.0.3
# [64] stringi_1.7.12 yaml_2.3.7 pkgbuild_1.4.0
# [67] plyr_1.8.8 grid_4.2.2 parallel_4.2.2
# [70] promises_1.2.0.1 crayon_1.5.2 miniUI_0.1.1.1
# [73] lattice_0.20-45 knitr_1.42 ps_1.7.2
# [76] pillar_1.8.1 igraph_1.3.5 markdown_1.4
# [79] shinystan_2.6.0 reshape2_1.4.4 codetools_0.2-18
# [82] stats4_4.2.2 rstantools_2.2.0 reprex_2.0.2
# [85] glue_1.6.2 evaluate_0.20 data.table_1.14.6
# [88] RcppParallel_5.1.6 vctrs_0.5.2 httpuv_1.6.8
# [91] gtable_0.3.1 assertthat_0.2.1 ggplot2_3.4.0
# [94] xfun_0.36 mime_0.12 xtable_1.8-4
# [97] coda_0.19-4 later_1.3.0 tibble_3.1.8
# [100] shinythemes_1.2.0 cmdstanr_0.5.3 ellipsis_0.3.2
# [103] bridgesampling_1.1-2
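The first reply asks for collapse 1.9.0 and the GitHub version of insight. A quick way to check an installed version before re-running the examples is utils::packageVersion(); the sketch below wraps it in a hypothetical helper, has_min_version(), that returns FALSE when the package is missing. Only the collapse version is spelled out in the thread, so that is the only entry filled in.

```r
# Hypothetical helper: TRUE if `pkg` is installed at `min_version` or newer
has_min_version <- function(pkg, min_version) {
  if (!requireNamespace(pkg, quietly = TRUE)) {
    return(FALSE)
  }
  utils::packageVersion(pkg) >= min_version
}

# Minimum versions mentioned in this thread
required <- c(collapse = "1.9.0")
vapply(names(required),
       function(p) has_min_version(p, required[[p]]),
       logical(1))
```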
Thanks!
The model is garbage - this is just a toy example (:
library(brms)
library(marginaleffects)
library(ggdist)
library(ggplot2)
mod <- brm(Species ~ .,
data = iris,
family = categorical(),
backend = "cmdstanr", cores = 4, iter = 300
)
predictions(mod, newdata = datagrid(Sepal.Width = seq(2, 4.5, length.out = 50))) |>
posterior_draws(shape = "rvar") |>
curve_interval(rvar, .along = c("Sepal.Width", "group")) |>
ggplot(aes(Sepal.Width, estimate, color = group)) +
facet_grid(~group) +
geom_ribbon(aes(ymin = conf.low, ymax = conf.high, fill = group),
color = NA, alpha = 0.4) +
geom_line()
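The toy example pipes the rvar column into ggdist::curve_interval(). Unlike the 2.5 % / 97.5 % columns, which are computed separately at each covariate value (pointwise), a curvewise interval ranks whole curves by how central they are and reports the envelope of the most central ones. The sketch below contrasts the two on simulated straight-line curves. Its depth measure loosely follows the mean halfspace depth idea behind curve_interval()'s default .interval = "mhd" (as I understand it), but it is a simplified stand-in, not ggdist's implementation.

```r
set.seed(42)
x <- seq(2, 4.5, length.out = 50)
n_draws <- 500

# Hypothetical posterior draws of whole curves (rows = draws, cols = x)
intercepts <- rnorm(n_draws, 0, 0.2)
slopes <- rnorm(n_draws, 0.3, 0.1)
curves <- intercepts + outer(slopes, x)

# Pointwise 50% interval: quantiles computed separately at each x
pointwise <- apply(curves, 2, quantile, probs = c(0.25, 0.75))

# Curvewise 50% band (simplified depth-based sketch):
# 1. pointwise halfspace depth of each draw at each x, via ranks
ranks <- apply(curves, 2, rank)
depth_pt <- pmin(ranks, n_draws + 1 - ranks) / n_draws
# 2. average the depth along each curve
depth <- rowMeans(depth_pt)
# 3. keep the deepest half of the curves and take their envelope
keep <- depth >= quantile(depth, 0.5)
curvewise <- rbind(
  lower = apply(curves[keep, , drop = FALSE], 2, min),
  upper = apply(curves[keep, , drop = FALSE], 2, max)
)
```

Because the curvewise band has to contain entire curves, it is typically wider than the pointwise band at any given x.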
Trying to run the example from https://github.com/vincentarelbundock/marginaleffects/issues/539#issuecomment-1317013858, I get: