mjskay / tidybayes

Bayesian analysis + tidy data + geoms (R package)
http://mjskay.github.io/tidybayes
GNU General Public License v3.0

Request: make `add_[linpred]_draws` work with models without predictors #320

Closed: JackCaster closed this issue 3 months ago

JackCaster commented 3 months ago

I have a distributional model without predictors. For example, I have some samples from a normal distribution and I'd like to recover the parameters of the underlying distribution.

I like the `add_[linpred]_draws` functions because they have a `transform` argument, which comes in handy when extracting draws of parameters that have been transformed (e.g., `sigma`). However, `add_[linpred]_draws` expects a data frame of new data, which has no obvious meaning when the model has no predictors. My current workaround is to pass a data frame with a dummy column and remove that column afterwards.

Would it be possible to make `add_[linpred]_draws` work without a new data frame?

Code:

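library(tidyverse)
library(brms)
library(tidybayes)

# Simulated data: 100 samples from a standard normal distribution
data <- tibble(y = rnorm(100))

# Intercept-only model (no predictors)
m <- brm(y ~ 1, data = data)

# Workaround: a one-row data frame with a dummy column, dropped afterwards
fit <- tibble(to_remove = NA) %>%
    add_linpred_draws(m, transform = TRUE, dpar = TRUE) %>%
    ungroup() %>%
    select(-to_remove)

# Summarise mu and sigma (.linpred is dropped because it duplicates mu)
fit %>%
    select(-.linpred) %>%
    median_qi() %>%
    as_tibble()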
# A tibble: 1 × 9
      mu mu.lower mu.upper sigma sigma.lower sigma.upper .width .point .interval
   <dbl>    <dbl>    <dbl> <dbl>       <dbl>       <dbl>  <dbl> <chr>  <chr>    
1 0.0487   -0.165    0.250  1.05       0.917        1.21   0.95 median qi   
mjskay commented 3 months ago

I prefer that the tidybayes functions require an input data frame, since it's good to be explicit about what data you're generating predictions for.

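That said, there is a solution to your problem that doesn't require creating a data frame with a fake predictor: perhaps counterintuitively, you can create a data frame with 1 row and 0 columns. The single row asks for one set of draws, and because the model has no predictors, no columns are needed.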

e.g. in base R:

data.frame()[1,]
## data frame with 0 columns and 1 row

or for a tibble:

tibble::tibble(.rows = 1L)
## # A tibble: 1 × 0

which works with tidybayes functions:

data.frame()[1,] |>
  add_linpred_draws(m, transform = TRUE, dpar = TRUE)
## # A tibble: 4,000 × 7
## # Groups:   .row [1]
##     .row .chain .iteration .draw .linpred       mu sigma
##    <int>  <int>      <int> <int>    <dbl>    <dbl> <dbl>
##  1     1     NA         NA     1  0.0199   0.0199  0.886
##  2     1     NA         NA     2  0.00326  0.00326 1.01 
##  3     1     NA         NA     3  0.0232   0.0232  0.981
##  4     1     NA         NA     4 -0.187   -0.187   0.996
##  5     1     NA         NA     5  0.300    0.300   0.920
##  6     1     NA         NA     6  0.258    0.258   0.926
##  7     1     NA         NA     7  0.188    0.188   1.02 
##  8     1     NA         NA     8  0.193    0.193   1.04 
##  9     1     NA         NA     9  0.195    0.195   1.02 
## 10     1     NA         NA    10  0.105    0.105   0.915
## # ℹ 3,990 more rows
## # ℹ Use `print(n = ...)` to see more rows
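As a quick sanity check that needs only base R (no brms or tidybayes), the snippet below confirms that `data.frame()[1, ]` really has one row and zero columns. The second constructor, `data.frame(row.names = 1L)`, is not mentioned in the thread; it is included as an assumed base-R equivalent.

```r
# Base-R ways to build a data frame with one row and zero columns.
# (tibble::tibble(.rows = 1L), shown above, is the tibble equivalent.)
empty_a <- data.frame()[1, ]           # subset row 1 out of an empty data frame
empty_b <- data.frame(row.names = 1L)  # assumption: equivalent constructor via row.names

# One row means one set of draws per posterior sample;
# zero columns means no predictor values are supplied.
stopifnot(nrow(empty_a) == 1L, ncol(empty_a) == 0L)
stopifnot(nrow(empty_b) == 1L, ncol(empty_b) == 0L)
```

Either object should then be usable in place of the dummy-column tibble from the original workaround, e.g. `data.frame()[1, ] |> add_linpred_draws(m, transform = TRUE, dpar = TRUE) |> select(-.linpred) |> median_qi()`, assuming `m` is the intercept-only model from the first comment and dplyr is loaded.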