mjskay / tidybayes

Bayesian analysis + tidy data + geoms (R package)
http://mjskay.github.io/tidybayes
GNU General Public License v3.0

Extract groups from intercept names with spread/gather_draws #321

Open JackCaster opened 5 months ago

JackCaster commented 5 months ago

I would like to use spread_draws()/gather_draws() to extract draws from a model whose intercept varies with a categorical predictor.

When the model is specified as y ~ 0 + group, brms produces variables named "b_groupA" and "b_groupB". However, the group cannot be extracted with m %>% gather_draws(b_group[group]). The intercepts can be extracted using a regex, but the group is not returned as an index variable:

m  %>%
    gather_draws(`b_.*`, regex = TRUE)
# Groups:   .variable [2]
   .chain .iteration .draw .variable .value
    <int>      <int> <int> <chr>      <dbl>
 1      1          1     1 b_groupA    104.
 2      1          2     2 b_groupA    104.
 3      1          3     3 b_groupA    104.
 4      1          4     4 b_groupA    104.
 5      1          5     5 b_groupA    106.
 6      1          6     6 b_groupA    105.

So far, the workaround for me has been to use separate_wider_regex but things get complex quite quickly.
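
For reference, the separate_wider_regex workaround might look roughly like the sketch below. The `draws` tibble is a hypothetical, ungrouped stand-in for the long-format output of gather_draws() (the values are made up), and the column names `term` and `group` are illustrative choices. It assumes tidyr 1.3.0 or later.

```r
library(dplyr)
library(tidyr)

# Hypothetical stand-in for gather_draws() output (ungrouped, made-up values)
draws <- tibble(
  .draw = rep(1:3, times = 2),
  .variable = rep(c("b_groupA", "b_groupB"), each = 3),
  .value = c(104.2, 104.9, 105.3, 102.1, 103.4, 102.8)
)

# Split ".variable" into the fixed prefix and the group level.
# The patterns are matched in sequence against the whole string.
wide <- draws |>
  separate_wider_regex(
    .variable,
    patterns = c(term = "b_group", group = ".*")
  )

wide
```

This is manageable with one predictor, but the pattern has to be rewritten for every additional term, which is where it becomes complex.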

Do you think your functions could be adapted to cover such a scenario?

Code:

library(tidyverse)
library(brms)
library(tidybayes)

N <- list(a = 25, b = 30) # sample size
MEAN <- list(a = 105, b = 103) # population mean
SD <- list(a = 2, b = 5) # population sd

ya <- tibble(y = rnorm(N$a, mean = MEAN$a, sd = SD$a))
yb <- tibble(y = rnorm(N$b, mean = MEAN$b, sd = SD$b))

df <- bind_rows(list(A = ya, B = yb), .id = "group")

m  <- brm(y ~ 0 + group, data = df)

get_variables(m)
#  [1] "b_groupA"      "b_groupB"      "sigma"         "lprior"        "lp__"          "accept_stat__" "treedepth__"   "stepsize__"    "divergent__"   "n_leapfrog__"  "energy__"     

m  %>%
    gather_draws(`b_.*`, regex = TRUE)

mjskay commented 5 months ago

Hmm, at some point I should probably support more arbitrary indexing schemes in those functions. Just added an issue for that here: #322

Unfortunately I don't have cycles for it at the moment (though happy to take PRs if someone wanted to tackle it), but I can suggest using rename_with with gsub to standardize names before using spread_draws() / gather_draws().

e.g. consider this data set:

set.seed(1234)
df = data.frame(
  a = rnorm(1000), 
  b_groupA = rnorm(1000, 1), 
  b_groupB = rnorm(1000, 2)
) |>
  posterior::as_draws_df()
df
#> # A draws_df: 1000 iterations, 1 chains, and 3 variables
#>        a b_groupA b_groupB
#> 1  -1.21    -0.21     1.03
#> 2   0.28     1.30     1.90
#> 3   1.08    -0.54     1.89
#> 4  -2.35     1.64     3.19
#> 5   0.43     1.70     0.34
#> 6   0.51    -0.91     0.95
#> 7  -0.57     1.94     0.26
#> 8  -0.55     0.78     2.51
#> 9  -0.56     0.33     1.55
#> 10 -0.89     1.45     0.16
#> # ... with 990 more draws
#> # ... hidden reserved variables {'.chain', '.iteration', '.draw'}

You can rename the columns into the format tidybayes expects using something like this:

df |>
  dplyr::rename_with(\(x) gsub("b_group(.*)", "b_group[\\1]", x))
#> # A draws_df: 1000 iterations, 1 chains, and 3 variables
#>        a b_group[A] b_group[B]
#> 1  -1.21      -0.21       1.03
#> 2   0.28       1.30       1.90
#> 3   1.08      -0.54       1.89
#> 4  -2.35       1.64       3.19
#> 5   0.43       1.70       0.34
#> 6   0.51      -0.91       0.95
#> 7  -0.57       1.94       0.26
#> 8  -0.55       0.78       2.51
#> 9  -0.56       0.33       1.55
#> 10 -0.89       1.45       0.16
#> # ... with 990 more draws
#> # ... hidden reserved variables {'.chain', '.iteration', '.draw'}

Which can be chained right into spread_draws, a la:

df |>
  dplyr::rename_with(\(x) gsub("b_group(.*)", "b_group[\\1]", x)) |>
  tidybayes::spread_draws(b_group[i]) |> 
  ggdist::median_qi()
#> # A tibble: 2 × 7
#>   i     b_group .lower .upper .width .point .interval
#>   <chr>   <dbl>  <dbl>  <dbl>  <dbl> <chr>  <chr>    
#> 1 A        1.01 -0.973   2.84   0.95 median qi       
#> 2 B        2.06  0.101   4.02   0.95 median qi
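
The same renaming step should also work with gather_draws(), which returns the long format asked about in the original post. The sketch below reuses the synthetic draws from above. The anchored pattern (`^...$`) and the index name `group` are choices made here for illustration, not part of the suggestion above.

```r
library(dplyr)
library(posterior)
library(tidybayes)

set.seed(1234)
draws <- data.frame(
  a = rnorm(1000),
  b_groupA = rnorm(1000, 1),
  b_groupB = rnorm(1000, 2)
) |>
  as_draws_df()

long <- draws |>
  # Anchor the pattern so only names that start with "b_group" are rewritten
  rename_with(\(x) gsub("^b_group(.*)$", "b_group[\\1]", x)) |>
  # Name the index "group" so it becomes a column of that name
  gather_draws(b_group[group])

long
```

For the brms model in the original post, the same pipeline would presumably start from posterior::as_draws_df(m) instead of the synthetic data frame. That variant has not been run here.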