JackCaster opened this issue 5 months ago.
Hmm, at some point I should probably support more arbitrary indexing schemes in those functions. Just added an issue for that here: #322
Unfortunately I don't have cycles for it at the moment (though I'm happy to take PRs if someone wants to tackle it), but I can suggest using dplyr::rename_with() with gsub() to standardize the names before calling spread_draws() / gather_draws().
For example, consider this data set:
set.seed(1234)
df = data.frame(
a = rnorm(1000),
b_groupA = rnorm(1000, 1),
b_groupB = rnorm(1000, 2)
) |>
posterior::as_draws_df()
df
#> # A draws_df: 1000 iterations, 1 chains, and 3 variables
#> a b_groupA b_groupB
#> 1 -1.21 -0.21 1.03
#> 2 0.28 1.30 1.90
#> 3 1.08 -0.54 1.89
#> 4 -2.35 1.64 3.19
#> 5 0.43 1.70 0.34
#> 6 0.51 -0.91 0.95
#> 7 -0.57 1.94 0.26
#> 8 -0.55 0.78 2.51
#> 9 -0.56 0.33 1.55
#> 10 -0.89 1.45 0.16
#> # ... with 990 more draws
#> # ... hidden reserved variables {'.chain', '.iteration', '.draw'}
You can rename the columns into the format tidybayes expects using something like this:
df |>
dplyr::rename_with(\(x) gsub("b_group(.*)", "b_group[\\1]", x))
#> # A draws_df: 1000 iterations, 1 chains, and 3 variables
#> a b_group[A] b_group[B]
#> 1 -1.21 -0.21 1.03
#> 2 0.28 1.30 1.90
#> 3 1.08 -0.54 1.89
#> 4 -2.35 1.64 3.19
#> 5 0.43 1.70 0.34
#> 6 0.51 -0.91 0.95
#> 7 -0.57 1.94 0.26
#> 8 -0.55 0.78 2.51
#> 9 -0.56 0.33 1.55
#> 10 -0.89 1.45 0.16
#> # ... with 990 more draws
#> # ... hidden reserved variables {'.chain', '.iteration', '.draw'}
This can then be piped straight into spread_draws():
df |>
dplyr::rename_with(\(x) gsub("b_group(.*)", "b_group[\\1]", x)) |>
tidybayes::spread_draws(b_group[i]) |>
ggdist::median_qi()
#> # A tibble: 2 × 7
#> i b_group .lower .upper .width .point .interval
#> <chr> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
#> 1 A 1.01 -0.973 2.84 0.95 median qi
#> 2 B 2.06 0.101 4.02 0.95 median qi
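The gsub() call in the rename step can be made a little more defensive by anchoring the pattern. Then only names that start with the prefix are rewritten. The unanchored "b_group(.*)" would also rewrite any name that merely contains "b_group" somewhere inside it.

The sketch below is plain base R. index_names() is a hypothetical helper name used only for illustration; it is not part of tidybayes or dplyr. The name "sd_b_groupA" is likewise made up, purely to show the effect of anchoring.

```r
# Hypothetical helper (not part of tidybayes or dplyr): rewrite names such
# as "b_groupA" into "b_group[A]" so that spread_draws() / gather_draws()
# can read the suffix as an index.
index_names <- function(x, prefix) {
  # `prefix` is pasted into the pattern as-is, so it is treated as a regular
  # expression; escape any metacharacters (e.g. ".") before calling this.
  # The pattern is anchored: a name must start with the prefix and have at
  # least one character after it; everything after the prefix is captured.
  pattern <- paste0("^(", prefix, ")(.+)$")
  sub(pattern, "\\1[\\2]", x)
}

# "sd_b_groupA" is a made-up name used only to show the effect of anchoring.
vars <- c("a", "b_groupA", "b_groupB", "sd_b_groupA", "lp__")

index_names(vars, "b_group")
# returns c("a", "b_group[A]", "b_group[B]", "sd_b_groupA", "lp__")

# The unanchored pattern also rewrites a name that merely contains "b_group":
gsub("b_group(.*)", "b_group[\\1]", "sd_b_groupA")
# returns "sd_b_group[A]"
```

Under these assumptions it would slot into the pipeline above in place of the anonymous function: df |> dplyr::rename_with(\(x) index_names(x, "b_group")), followed by spread_draws(b_group[i]) as before.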
I would like to use spread_draws() / gather_draws() to extract draws from a model whose intercept varies with a predictor. When the model is specified as y ~ 0 + group, brms produces variables named "b_groupA" and "b_groupB", but it is not possible to extract the group with m %>% gather_draws(b_group[group]). I can extract the intercepts using a regex, but not the group index variable.
So far, my workaround has been to use separate_wider_regex(), but things get complex quite quickly. Do you think your functions could be adapted to cover such a scenario?
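The separate_wider_regex() workaround mentioned above can be roughly sketched in base R, with sub() standing in for the tidyr function. After the draws are in long format, each .variable value is split into the coefficient stem and the group level.

This is a minimal sketch that assumes a long table with .draw, .variable and .value columns. The rows are typed in by hand with illustrative values rather than produced by a real model.

```r
# A few rows shaped like long-format draws (columns .draw, .variable,
# .value); the values are illustrative only, not from a fitted model.
long <- data.frame(
  .draw     = c(1, 1, 2, 2),
  .variable = c("b_groupA", "b_groupB", "b_groupA", "b_groupB"),
  .value    = c(-0.21, 1.03, 1.30, 1.90)
)

# Anchored pattern: capture group 1 is the coefficient stem, group 2 the level.
pattern <- "^(b_group)(.+)$"

# Extract the level first, then overwrite .variable with the stem.
long$group     <- sub(pattern, "\\2", long$.variable)
long$.variable <- sub(pattern, "\\1", long$.variable)

# The group column can now be used like any other grouping variable,
# e.g. a per-group median of the draws.
aggregate(.value ~ group, data = long, FUN = median)
```

Each further coefficient family needs its own alternative in the pattern, e.g. "^(b_group|b_site)(.+)$", where b_site is a hypothetical second family. This is one place where the approach starts to get unwieldy.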