Originally posted by **reubenharry** December 5, 2023
I'm opening this discussion to resolve the discrepancy between how the MCLMC repo calculates ESS and how BlackJax's mclmc algorithm does it. Because @JakobRobnik and @junpenglao have different opinions here, and I don't fully understand the details, I thought it would be better to create a discussion and tag you both.
The main thing is that it would be nice to resolve this before the next version of BlackJax is published.
Here is how BlackJax does it:
```python
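import jax
from jax.flatten_util import ravel_pytree
from blackjax.diagnostics import effective_sample_size
flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples)  # flatten each sample: (num_steps, dim)
flat_samples = flat_samples.reshape(2, num_steps // 2, -1)  # split the single chain into two half-chains
ESS = effective_sample_size(flat_samples)  # axis 0 is treated as chains, axis 1 as samples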
```
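For anyone unsure what the `reshape` is doing: a small sketch, using NumPy as a stand-in for `jax.numpy` (the reshape semantics are the same) and toy values for `num_steps` and `dim` that are purely illustrative. It checks that the first "chain" is the first half of the run and the second "chain" is the second half.

```python
import numpy as np

# Toy stand-in for the flattened samples; the shapes here are illustrative only.
num_steps, dim = 6, 2
flat_samples = np.arange(num_steps * dim).reshape(num_steps, dim)

# Same reshape as in the snippet above: (num_steps, dim) -> (2, num_steps // 2, dim).
halves = flat_samples.reshape(2, num_steps // 2, -1)

# Row-major reshape keeps time order: half 0 is steps 0..2, half 1 is steps 3..5.
first_half_matches = bool((halves[0] == flat_samples[: num_steps // 2]).all())
second_half_matches = bool((halves[1] == flat_samples[num_steps // 2 :]).all())
print(halves.shape, first_half_matches, second_half_matches)
```

Note that this reshape only works when `num_steps` is even; with an odd number of steps the element count no longer matches.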
Jakob's opinion:
> Following the link that you sent me (https://mc-stan.org/docs/2_18/reference-manual/effective-sample-size-section.html), they want to estimate $\rho_t = \langle \rho_t^m \rangle_m + 1 - 1/R^2$, where $\langle \cdot \rangle_m$ is the average over chains.
> If we only have one chain, I don't see anything wrong with just taking $\rho_t^1$
> We don't really care about Gelman-Rubin here.
> I am not sure what effect splitting the chain will have, but you certainly lose a factor of two in resolution.
...
> we need to check that we have not made the estimator worse (i.e. compare both estimators as a function of samples used).
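
One way to run the comparison suggested above is on a chain whose ESS is known. The following is a rough sketch under my own assumptions, not code from either repo: it uses an AR(1) process with coefficient `phi`, for which the ESS of the mean is $N(1-\phi)/(1+\phi)$, and compares (a) the single-chain estimator built from $\rho_t^1$ with (b) the split-chain estimator using the combined formula from the Stan manual, $\hat\rho_t = 1 - \left(W - \frac{1}{M}\sum_m s_m^2 \hat\rho_{t,m}\right)/\widehat{\mathrm{var}}^+$. All function names here (`ar1`, `autocov`, `ess_single`, `ess_split`) are hypothetical, and the truncation rule is a simplified version of Geyer's positive-pair rule.

```python
import numpy as np


def ar1(n, phi, rng):
    """Stationary AR(1) chain with unit variance (toy target with known ESS)."""
    eps = rng.normal(size=n) * np.sqrt(1.0 - phi**2)
    x = np.empty(n)
    x[0] = rng.normal()
    for t in range(1, n):
        x[t] = phi * x[t - 1] + eps[t]
    return x


def autocov(x):
    """Biased autocovariance at all lags, computed with a zero-padded FFT."""
    n = len(x)
    x = x - x.mean()
    f = np.fft.rfft(x, 2 * n)
    return np.fft.irfft(f * np.conj(f), 2 * n)[:n] / n


def ess_from_rho(rho, total):
    """tau = -1 + 2 * sum of consecutive pairs of rho, stopping at the first negative pair."""
    tau = -1.0
    for t in range(0, len(rho) - 1, 2):
        pair = rho[t] + rho[t + 1]
        if pair < 0:
            break
        tau += 2.0 * pair
    return total / tau


def ess_single(x):
    """Single chain: use the plain autocorrelation rho_t^1."""
    acov = autocov(x)
    return ess_from_rho(acov / acov[0], len(x))


def ess_split(x):
    """Split the chain in two halves and combine them as in the Stan manual."""
    half = len(x) // 2
    chains = x[: 2 * half].reshape(2, half)
    m, n = chains.shape
    acov = np.stack([autocov(c) for c in chains]) * n / (n - 1)  # s_m^2 * rho_{t,m}
    w = acov[:, 0].mean()                       # within-chain variance W
    b_over_n = chains.mean(axis=1).var(ddof=1)  # between-chain variance B / N
    var_plus = w * (n - 1) / n + b_over_n
    rho = 1.0 - (w - acov.mean(axis=0)) / var_plus
    return ess_from_rho(rho, m * n)


rng = np.random.default_rng(0)
phi, num_steps = 0.9, 20_000
x = ar1(num_steps, phi, rng)

true_ess = num_steps * (1 - phi) / (1 + phi)
single = ess_single(x)
split = ess_split(x)
print(f"true: {true_ess:.0f}  single-chain: {single:.0f}  split-chain: {split:.0f}")
```

To compare the two estimators "as a function of samples used", one could repeat this over several seeds and several values of `num_steps` and look at the bias and spread of each estimate relative to `true_ess`. An AR(1) chain is of course much simpler than MCLMC output, so this only checks the estimators themselves.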
Discussed in https://github.com/blackjax-devs/blackjax/discussions/607