blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0

The ESS calculation for 1 chain #608

Closed: junpenglao closed this issue 10 months ago.

junpenglao commented 10 months ago

Discussed in https://github.com/blackjax-devs/blackjax/discussions/607

Originally posted by **reubenharry** December 5, 2023

I'm opening this discussion to resolve the discrepancy between how the MCLMC repo calculates ESS and how BlackJax's mclmc algorithm does it. Because @JakobRobnik and @junpenglao have different opinions here, and I don't fully understand the details, I thought it would be better to create a discussion and tag you both. The main thing is that it would be nice to resolve this before the next version of BlackJax is published.

Here is how blackjax does it:

```python
import jax
from jax.flatten_util import ravel_pytree
from blackjax.diagnostics import effective_sample_size
flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples)  # flatten each sample's pytree into one vector
flat_samples = flat_samples.reshape(2, num_steps // 2, -1)  # split the single chain into 2 halves: (chains, draws, dims)
ESS = effective_sample_size(flat_samples)
```

Jakob's opinion:

> Following the link that you sent me (https://mc-stan.org/docs/2_18/reference-manual/effective-sample-size-section.html), they want to estimate $\rho_t = \text{average over chains}(\rho_t^m) + 1 - 1/R^2$
>
> If we only have one chain, I don't see anything wrong with just taking $\rho_t^1$
>
> We don't really care about Gelman-Rubin here.
>
> I am sure what effect splitting the chain will have, for sure you lose a factor of two in resolution.
>
> ...
>
> we need to check that we have not made the estimator worse (i.e. compare both estimators as a function of samples used).
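To make the comparison Jakob asks for concrete, here is a minimal, self-contained sketch in plain NumPy. It does not call BlackJax's `effective_sample_size` and is not meant to reproduce either BlackJax's or the MCLMC repo's code. It assumes Stan's definitions from the linked reference-manual section (biased FFT autocovariance, the multi-chain combination of per-chain autocorrelations, and Geyer's initial monotone sequence), and for a single chain it simply uses $\rho_t^1$, as Jakob proposes. The helper names `autocovariance`, `ess`, and `ar1_chains` are hypothetical. An AR(1) chain is used as the test case only because its true ESS is known in closed form, $N(1-\phi)/(1+\phi)$.

```python
import numpy as np


def autocovariance(x):
    """Biased (divide-by-N) autocovariance of a 1-D array at lags 0..N-1, via FFT."""
    n = x.shape[0]
    x = x - x.mean()
    # Zero-pad to 2N so the circular correlation computed by the FFT equals the linear one.
    f = np.fft.rfft(x, n=2 * n)
    return np.fft.irfft(f * np.conj(f), n=2 * n)[:n] / n


def ess(chains):
    """Stan-style ESS for a scalar quantity; `chains` has shape (num_chains, num_draws).

    Assumption: with one chain this uses that chain's own autocorrelation rho_t^1.
    With several chains it uses Stan's combined estimate
    rho_t = 1 - (W - mean_m(s_m^2 * rho_t^m)) / var_plus.
    """
    chains = np.asarray(chains, dtype=float)
    m, n = chains.shape
    acov = np.stack([autocovariance(c) for c in chains])  # (m, n)
    rho_m = acov / acov[:, :1]                             # per-chain autocorrelations
    if m == 1:
        rho = rho_m[0]
    else:
        s2 = acov[:, 0] * n / (n - 1)                      # within-chain variances s_m^2
        w = s2.mean()                                      # W
        b_over_n = chains.mean(axis=1).var(ddof=1)         # B / N
        var_plus = (n - 1) / n * w + b_over_n
        rho = 1.0 - (w - (s2[:, None] * rho_m).mean(axis=0)) / var_plus
    # Geyer's initial monotone sequence over pairs P_k = rho_{2k} + rho_{2k+1}:
    # stop at the first non-positive pair and force the pairs to be non-increasing.
    tau, prev = -1.0, np.inf
    for k in range(n // 2):
        p = rho[2 * k] + rho[2 * k + 1]
        if p <= 0:
            break
        p = min(p, prev)
        tau += 2.0 * p
        prev = p
    return m * n / tau


def ar1_chains(reps, n, phi, rng):
    """`reps` independent stationary AR(1) chains of length n.

    Their integrated autocorrelation time is (1 + phi) / (1 - phi), so the
    true ESS of one chain is n * (1 - phi) / (1 + phi).
    """
    x = np.empty((reps, n))
    x[:, 0] = rng.normal(size=reps) / np.sqrt(1 - phi**2)  # stationary start
    eps = rng.normal(size=(reps, n))
    for t in range(1, n):
        x[:, t] = phi * x[:, t - 1] + eps[:, t]
    return x


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    phi = 0.9
    for n in (500, 2_000, 8_000):
        true_ess = n * (1 - phi) / (1 + phi)
        xs = ar1_chains(200, n, phi, rng)
        # Unsplit: one chain, rho_t^1 only. Split: the same chain cut into 2 halves.
        unsplit = np.array([ess(x[None, :]) for x in xs]) / true_ess
        split = np.array([ess(x.reshape(2, n // 2)) for x in xs]) / true_ess
        print(f"n={n:>5}  unsplit ESS/true = {unsplit.mean():.2f} +/- {unsplit.std():.2f}"
              f"  split ESS/true = {split.mean():.2f} +/- {split.std():.2f}")
```

For each chain length, the script prints the mean and spread of each estimator's ESS relative to the true value over 200 replicate chains. The bias and variance of the two estimators can then be compared as a function of the number of samples used. No particular outcome is claimed here.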