stan-dev / posterior

The posterior R package
https://mc-stan.org/posterior/

Dealing with chains of different length #334

Open n-kall opened 10 months ago

n-kall commented 10 months ago

I suppose there is an implicit expectation that all chains have the same length, but chains of different lengths can arise when filtering draws, e.g. with dplyr. I think the recommended workflow should be to merge chains before doing this kind of filtering, but I guess that can't be enforced.
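Here is a minimal sketch of the merge-first workflow, assuming the posterior and dplyr packages are installed. Calling merge_chains() on example_draws() before filtering leaves a single chain, so there are no unequal chain lengths for the later conversion to draws_array to trip over. The object name fdraws_merged is only for illustration.

```r
library(posterior)
library(dplyr)

# Collapse the four chains into one *before* filtering, so filtering
# cannot leave chains of unequal length behind.
fdraws_merged <- example_draws() |>
  merge_chains() |>
  as_draws_df() |>
  dplyr::filter(mu > 0) |>
  repair_draws()

nchains(fdraws_merged)      # a single chain
niterations(fdraws_merged)  # equals ndraws(fdraws_merged)

# With only one chain, conversion to draws_array should no longer fail.
as_draws_array(fdraws_merged)
```

The cost is that per-chain diagnostics such as R-hat are no longer meaningful on the merged object, which is why this can only be a recommendation rather than something enforced.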

Currently, the draws formats compute niterations() differently when chains have different lengths (even after repair_draws()). And draws_array does not allow chains of different lengths at all; I think its error message could point to merge_chains() as a solution.

Consider filtering a draws_df based on the value of some variable (dplyr is used below, but the same can be done in base R). The result can be a draws_df whose chains have different lengths.

library(posterior)
library(dplyr)
fdraws <- example_draws() |>
  as_draws_df() |>
  dplyr::filter(`mu` > 0) |>
  repair_draws()

niterations(fdraws)
## [1] 94

niterations(as_draws_matrix(fdraws))
## [1] 89.5

niterations(as_draws_list(fdraws))
## [1] 90

as_draws_array(fdraws)
## Error in abind::abind(x, along = 3L) : 
##   arg 'X2' has dims=86, 10, 1; but need dims=90, 10, X
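Counting the retained draws per chain shows why the formats disagree. This is a sketch. It assumes that the 89.5 reported for draws_matrix comes from dividing the total number of draws by the number of chains. That is consistent with the output above, but it has not been checked against the package source. The last line shows merge_chains() used as a workaround before converting.

```r
library(posterior)
library(dplyr)

fdraws <- example_draws() |>
  as_draws_df() |>
  dplyr::filter(mu > 0) |>
  repair_draws()

# Number of retained draws in each chain; after filtering these differ.
draws_per_chain <- table(fdraws[[".chain"]])
draws_per_chain

# Total draws divided by number of chains; can be non-integer when
# chains differ in length.
ndraws(fdraws) / nchains(fdraws)

# Workaround: merge the chains into one, then convert.
as_draws_array(merge_chains(fdraws))
```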

Should a message or warning suggesting merge_chains() be given when a draws object has chains of different lengths?
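One possible form for such a warning is sketched below. check_chain_lengths() is a hypothetical helper and is not part of posterior. It converts its input to draws_df, tabulates the .chain column, and warns with a pointer to merge_chains() when the per-chain counts differ.

```r
library(posterior)

# Hypothetical helper (not part of posterior): warn when the chains of a
# draws object contain different numbers of draws.
check_chain_lengths <- function(x) {
  x <- as_draws_df(x)
  chain_sizes <- table(x[[".chain"]])
  if (length(unique(as.vector(chain_sizes))) > 1L) {
    warning(
      "Chains have different lengths (",
      paste(as.vector(chain_sizes), collapse = ", "),
      "). Consider calling merge_chains() first.",
      call. = FALSE
    )
  }
  invisible(x)
}

# Usage: the unfiltered example draws pass silently; filtered ones warn.
check_chain_lengths(example_draws())
check_chain_lengths(dplyr::filter(as_draws_df(example_draws()), mu > 0))
```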

mjskay commented 10 months ago

Hmm, yeah. To be consistent with other cases where this comes up, conversion functions from draws_df (which I believe is the only format that supports unequal chain lengths) to other formats should probably check for this. When it happens, they should either raise an error or merge chains and warn with warn_merge_chains("index").
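Below is a sketch of the "merge chains and warn" option for converting to draws_array. as_draws_array_merging() is a hypothetical wrapper, not posterior's implementation. It uses a plain base::warning() as a stand-in for warn_merge_chains("index"), and the warning text is made up for this example.

```r
library(posterior)

# Hypothetical wrapper (not part of posterior) sketching the "merge chains
# and warn" option for draws_df -> draws_array conversion.
as_draws_array_merging <- function(x) {
  x <- as_draws_df(x)
  chain_sizes <- table(x[[".chain"]])
  if (length(unique(as.vector(chain_sizes))) > 1L) {
    # Stand-in for warn_merge_chains("index"); wording invented here.
    warning("Chains have different lengths; merging them into a single chain.",
            call. = FALSE)
    # Renumber iterations and draws first, then collapse into one chain.
    x <- merge_chains(repair_draws(x))
  }
  as_draws_array(x)
}

# Usage with the filtered draws from the issue (requires dplyr).
fdraws <- example_draws() |>
  as_draws_df() |>
  dplyr::filter(mu > 0) |>
  repair_draws()
arr <- as_draws_array_merging(fdraws)  # warns, then returns one chain
nchains(arr)
```

To raise an error instead, the warning() and merge step would be replaced by a stop() that suggests merge_chains().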