mjskay / tidybayes

Bayesian analysis + tidy data + geoms (R package)
http://mjskay.github.io/tidybayes
GNU General Public License v3.0

predicted_draws for a brms model does not bring in .chain or .iteration #301

Open rundel opened 2 years ago

rundel commented 2 years ago

My expectation is that predicted_draws() and related functions should include the .chain and .iteration details from the model, but currently they are just NA. The values can be recovered via a join, but it is a bit of a headache and involves an unnecessary gather_draws() or similar.

See the reprex below:

d = data.frame(
  x = rnorm(100),
  y = rnorm(100)
)

b = brms::brm(y~x, data=d, silent=2, refresh=0)

tidybayes::predicted_draws(b, d)
#> # A tibble: 400,000 × 7
#> # Groups:   x, y, .row [100]
#>        x      y  .row .chain .iteration .draw .prediction
#>    <dbl>  <dbl> <int>  <int>      <int> <int>       <dbl>
#>  1 0.132 -0.883     1     NA         NA     1       0.573
#>  2 0.132 -0.883     1     NA         NA     2      -0.293
#>  3 0.132 -0.883     1     NA         NA     3       0.367
#>  4 0.132 -0.883     1     NA         NA     4      -0.924
#>  5 0.132 -0.883     1     NA         NA     5      -0.547
#>  6 0.132 -0.883     1     NA         NA     6       0.411
#>  7 0.132 -0.883     1     NA         NA     7      -0.138
#>  8 0.132 -0.883     1     NA         NA     8      -1.27 
#>  9 0.132 -0.883     1     NA         NA     9       1.20 
#> 10 0.132 -0.883     1     NA         NA    10      -0.361
#> # … with 399,990 more rows

This appears to be the case with models from rstanarm as well.

mjskay commented 2 years ago

Yes, this is because posterior_predict() for those models does not return chain/iteration information. In some cases this could be retrofitted onto the resulting object, though I'm not sure it can be done in all cases (e.g. with subsampling). Since there wasn't a generic solution I was sure would work for all parameters that might be passed down to these functions, I opted not to include that information rather than risk it being wrong in some cases. I'm willing to be convinced otherwise if there's a reliable solution.

If you are using the full set of draws (no subsampling), I believe .chain should be floor((.draw - 1) / (n_draws / n_chains)) + 1 and .iteration should be (.draw - 1) %% (n_draws / n_chains) + 1, but I would double-check that.
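
For concreteness, a minimal sketch of that mapping (assuming equal-length chains, draws ordered chain-by-chain with no subsampling, and 1-based indexing, which is where the extra "- 1" comes from):

n_chains     <- 4
n_iterations <- 1000                      # draws kept per chain, i.e. n_draws / n_chains
n_draws      <- n_chains * n_iterations

.draw      <- seq_len(n_draws)            # 1, 2, ..., n_draws in chain-major order
.chain     <- floor((.draw - 1) / n_iterations) + 1
.iteration <- (.draw - 1) %% n_iterations + 1

head(data.frame(.draw, .chain, .iteration))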

rundel commented 2 years ago

Thanks for the clarification. I had not considered the issues around ndraws / draw_ids, which has led me to do some digging in brms to wrap my head around what is happening with posterior_predict().

It seems like part of the issue is that posterior_predict() calls prepare_predictions(), which then calls posterior::as_draws_matrix(). That initially seems to preserve the draw "ids", but they are eventually lost because of how subset_draws() works and how it "repairs" the draw ids: in instances where ndraws / draw_ids are used, the original ids are completely lost and the resulting draws are indexed from 1 to n. The behavior of posterior seems a bit bizarre to me, but I'm sure there are reasons for these specific behaviors.
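
A minimal sketch of that renumbering using posterior's built-in example draws (the repaired ids noted in the comments reflect the behavior described above, not independently verified output):

library(posterior)

x <- as_draws_matrix(example_draws())    # eight schools example: 4 chains x 100 iterations
draw_ids(x)                              # 1, 2, ..., 400

# subsetting "repairs" the draw ids, so the original ids 5, 105, and 205
# come back renumbered as 1, 2, 3
sub <- subset_draws(x, draw = c(5, 105, 205))
draw_ids(sub)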

With all of that said, it does seem like it would be possible to provide the .chain and .iteration information in cases where ndraws = NULL and draw_ids = NULL, since it is possible to match .draw against the .draw column of as_tibble(as_draws_df(b)) to recover .chain and .iteration.

The formulas provided above make sense to me, but they seem potentially fragile if there were ever any weirdness around the ordering, compared to just using posterior plus the brmsfit object to fill in the blanks.

rundel commented 2 years ago

One other quick thought I just had: in the case of draw_ids, the function(s) will already have the draw ids, in which case they can be used to recover .chain and .iteration.

In the case of ndraws, instead of letting prepare_predictions() handle the conversion of ndraws into draw_ids (see here), this could be done by a similar call in the tidybayes function(s). In that case the above option should again work.

mjskay commented 2 years ago

In the case of ndraws, instead of letting prepare_predictions() handle the conversion of ndraws into draw_ids (see here), this could be done by a similar call in the tidybayes function(s). In that case the above option should again work.

Not a bad idea. This could be a good way to handle random subsets as well, rather than the current method which may be fragile in some cases.
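
A hypothetical sketch of that idea, assuming draw_ids is passed through to brms::posterior_predict() and that the resulting .draw column is renumbered 1..length(draw_ids) as described above (neither assumption verified here):

# lookup table mapping each original draw id to its .chain / .iteration
lookup <- posterior::as_draws_df(b) |>
  tibble::as_tibble() |>
  dplyr::select(.chain, .iteration, .draw)

# choose the subset of draw ids ourselves instead of passing ndraws
ids <- sort(sample(nrow(lookup), 100))

tidybayes::predicted_draws(b, d, draw_ids = ids) |>
  dplyr::mutate(.draw = ids[.draw]) |>   # map the repaired 1..n ids back to the originals
  dplyr::select(-.chain, -.iteration) |>
  dplyr::left_join(lookup, by = ".draw")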

WillTirone commented 1 year ago

I'm going to attempt a fix for this. I haven't dug in deep yet, but please let me know @mjskay if anything fundamental has changed since the past comments. Otherwise I'll proceed with the above! Thanks.

mjskay commented 1 year ago

Sure, would love a fix! I'd probably double-check with @paul-buerkner to see if there's a canonical way to get chain and iteration info out of posterior_predict() and the like.

paul-buerkner commented 1 year ago

I think this is related to https://github.com/paul-buerkner/brms/issues/1534. We likely have to wait until brms 3.0 for this feature.

WillTirone commented 1 year ago

Makes sense, thank you @paul-buerkner. We have some code that's a bit of a workaround (see below). I assume the preference is to wait until brms 3.0 rather than adding a temporary fix in tidybayes?

# Work around the missing chain/iteration info by joining it back in from
# tidy_draws(), matching on .draw
fix_draws = function(object, newdata, ..., func = tidybayes::predicted_draws) {
  draws = func(object, newdata, ...)

  # remember the original column order
  n = names(draws)

  dplyr::full_join(
    # drop the all-NA .chain / .iteration columns ...
    draws |> dplyr::select(-.chain, -.iteration),
    # ... and bring in the real values from tidy_draws(), keyed by .draw
    tidybayes::tidy_draws(object) |>
      dplyr::select(.chain, .iteration, .draw),
    by = ".draw"
  ) |>
    dplyr::select(dplyr::all_of(n)) |>  # restore the original column order
    dplyr::ungroup()
}
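
For reference, hypothetical usage against the reprex above (assuming the workaround behaves as intended) would look like:

fix_draws(b, d)                                   # predicted_draws() output with .chain / .iteration filled in
fix_draws(b, d, func = tidybayes::epred_draws)    # the same trick for the other *_draws functions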