mjskay / tidybayes

Bayesian analysis + tidy data + geoms (R package)
http://mjskay.github.io/tidybayes
GNU General Public License v3.0

For multinomial (categorical) models, add_linpred_draws() doesn't return the original outcome labels #310

Open tfjaeger opened 1 year ago

tfjaeger commented 1 year ago

Hi,

I'm using tidybayes to visualize the results of a (mixed-effects) categorical ('multinomial') regression fit with brms::brm(). Even when I apply recover_types() to the model before calling add_linpred_draws(), the output of add_linpred_draws() contains only integer codes in the .category column rather than the original labels of the outcome.

For example:

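library(tidyverse)
library(brms)
library(tidybayes)
library(magrittr)  # provides %<>%
# make some categorical data; rmultinom() returns a 3 x 100 matrix, so transpose it first
d <-
  rmultinom(n = 100, size = 1, prob = c(.1, .4, .5)) %>%
  t() %>%
  as_tibble(.name_repair = ~ paste0("V", seq_along(.x))) %>%
  mutate(response = case_when(V1 == 1 ~ "A", V2 == 1 ~ "B", TRUE ~ "C")) %>%
  select(response)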

# fit categorical model
m <- brm(
  bf(response ~ 1),
  family = categorical,
  data = d,
  backend = "cmdstanr")

# apply (or not) recover types to model
m %<>% recover_types(d)

# get draws from linear predictor
d %>%
  add_linpred_draws(
    object = m,
    ndraws = 100,
    re_formula = NA,
    value = "logit",
    category = "predicted_response")

yields:

# A tibble: 20,000 × 7
# Groups:   response, .row, predicted_response [200]
   response  .row .chain .iteration .draw predicted_response logit
   <chr>    <int>  <int>      <int> <int> <fct>              <dbl>
 1 B            1     NA         NA     1 1                   1.22
 2 B            1     NA         NA     2 1                   1.09
 3 B            1     NA         NA     3 1                   1.35
 4 B            1     NA         NA     4 1                   2.01
 5 B            1     NA         NA     5 1                   1.79
 6 B            1     NA         NA     6 1                   2.10
 7 B            1     NA         NA     7 1                   1.69
 8 B            1     NA         NA     8 1                   1.57
 9 B            1     NA         NA     9 1                   2.09
10 B            1     NA         NA    10 1                   1.67
# ℹ 19,990 more rows
# ℹ Use `print(n = ...)` to see more rows

rather than:

# A tibble: 20,000 × 7
# Groups:   response, .row, predicted_response [200]
   response  .row .chain .iteration .draw predicted_response logit
   <chr>    <int>  <int>      <int> <int> <fct>              <dbl>
 1 B            1     NA         NA     1 B                   1.22
 2 B            1     NA         NA     2 B                   1.09
 3 B            1     NA         NA     3 B                   1.35
 4 B            1     NA         NA     4 B                   2.01
 5 B            1     NA         NA     5 B                   1.79
 6 B            1     NA         NA     6 B                   2.10
 7 B            1     NA         NA     7 B                   1.69
 8 B            1     NA         NA     8 B                   1.57
 9 B            1     NA         NA     9 B                   2.09
10 B            1     NA         NA    10 B                   1.67
# ℹ 19,990 more rows
# ℹ Use `print(n = ...)` to see more rows

This happens even though the original category labels are available in the model object, as the model summary shows:

 Family: categorical 
  Links: muB = logit; muC = logit 
Formula: response ~ 1 
   Data: d (Number of observations: 100) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Population-Level Effects: 
              Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
muB_Intercept     1.66      0.36     1.01     2.44 1.00     1043      929
muC_Intercept     1.51      0.36     0.86     2.28 1.00     1071      906

(I also note that the transform argument doesn't work for this type of model, though I imagine that issue lies with brms?)

Thank you for looking into this! (and apologies if it's a false alarm)

tfjaeger commented 1 year ago

The relevant categories are available in model$family$cats and the reference category is available in model$family$refcat.
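
In the meantime, a possible workaround is to map the integer codes back to labels by hand. The sketch below is only an illustration under stated assumptions: relabel_categories() is a hypothetical helper (not part of tidybayes or brms), and it assumes the integer codes index the non-reference categories in the order they appear in model$family$cats, which matches the output above (two categories, muB and muC) but has not been checked against other models. The cats and refcat vectors are stand-ins for model$family$cats and model$family$refcat so that the example runs without a fitted model.

```r
# Hypothetical helper: convert integer category codes (as returned in the
# .category column) into the outcome labels stored in the model family.
# Assumption: if there are fewer codes than labels, the reference category
# was dropped and the remaining labels keep their original order.
relabel_categories <- function(category, cats, refcat = NULL) {
  category <- factor(category)
  codes <- as.integer(as.character(category))
  labels <- if (nlevels(category) == length(cats)) cats else setdiff(cats, refcat)
  # Guard against codes that have no matching label
  stopifnot(max(codes, na.rm = TRUE) <= length(labels))
  factor(labels[codes], levels = labels)
}

# Stand-ins for model$family$cats and model$family$refcat
cats <- c("A", "B", "C")
refcat <- "A"

# Stand-in for the integer-coded category column
category <- factor(c(1, 1, 2, 2))

relabelled <- relabel_categories(category, cats, refcat)
print(relabelled)
```

With the real draws, the helper would presumably be applied inside mutate() after an ungroup(), since the category column is one of the grouping variables in the output above and dplyr does not allow grouping columns to be modified in place.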