minaskar / zeus

⚡️ zeus: Lightning Fast MCMC ⚡️
https://zeus-mcmc.readthedocs.io/
GNU General Public License v3.0

Computation of R-hat Statistic #29

Open Bobby-Huggins opened 2 years ago

Bobby-Huggins commented 2 years ago

First of all, thanks for the package and all your hard work!

I think I've encountered a couple of issues/bugs in the computation of the R-hat statistic.

  1. The first is just a typo, I think. In lines 139-140 the chain means and variances are flattened by the list comprehension, whereas I think something like:

    _means = np.vstack(means)
    _vars = np.vstack(vars)

    will keep the structure of the chains, so that np.var(_means, ddof=1, axis=0) and np.mean(_vars, axis=0) will give the between-chain and within-chain variance, respectively, for each parameter (right now they end up as scalars, since _means and _vars are flat).

  2. This is similar in spirit to Issue #22, but with a more significant effect here. On lines 120-121, where each split is reshaped to (-1, ndim), samples from all walkers are collapsed into each split, which homogenizes the splits and leads to unrealistically low R-hat values. In my case I had ~28 walkers, many of which were stuck in well-separated modes and barely mixing at all. They nonetheless had a quite low split R-hat as recorded by the callback, because each of the two splits contained samples from all 28 walkers, making the splits statistically similar.

    Since this is an ensemble method I had to spend some time convincing myself, but I really think R-hat should be computed across all (possibly split) walkers, rather than by grouping them together. I can share my trace plots and make a case for this in more detail if it's helpful. With change (1) above, fixing this would just be a matter of removing the reshape operation. Then nsplits would determine how many splits are made within each walker.
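
To make the two points concrete, here is a minimal, self-contained sketch. It is not the zeus code: the function names (`rhat`, `split_rhat_pooled`, `split_rhat_per_walker`), the `(nsteps, nwalkers, ndim)` chain layout, the use of `ddof=1` for the within-chain variance, and the synthetic "stuck walkers" data are all my own assumptions for illustration. It compares pooling all walkers into each split against treating every split of every walker as its own chain.

```python
import numpy as np


def rhat(means, variances, n):
    """Per-parameter R-hat from per-chain means and variances.

    means, variances: sequences of (ndim,) arrays, one entry per chain.
    n: number of samples in each chain.
    """
    # vstack keeps the (nchains, ndim) structure instead of flattening it.
    _means = np.vstack(means)
    _vars = np.vstack(variances)
    B = n * np.var(_means, ddof=1, axis=0)  # between-chain variance, (ndim,)
    W = np.mean(_vars, axis=0)              # within-chain variance, (ndim,)
    var_hat = (1.0 - 1.0 / n) * W + B / n
    return np.sqrt(var_hat / W)


def split_rhat_pooled(chain, nsplits=2):
    """Pool all walkers into each split (the behaviour described in point 2)."""
    ndim = chain.shape[-1]
    splits = np.array_split(chain, nsplits)        # split along the step axis
    n = min(len(s) for s in splits)
    flat = [s[:n].reshape(-1, ndim) for s in splits]
    means = [f.mean(axis=0) for f in flat]
    variances = [f.var(axis=0, ddof=1) for f in flat]
    return rhat(means, variances, len(flat[0]))


def split_rhat_per_walker(chain, nsplits=2):
    """Treat every split of every walker as a separate chain (no reshape)."""
    splits = np.array_split(chain, nsplits)
    n = min(len(s) for s in splits)
    means, variances = [], []
    for s in splits:
        s = s[:n]
        # mean/var over steps have shape (nwalkers, ndim);
        # extend() adds one (ndim,) row per walker.
        means.extend(s.mean(axis=0))
        variances.extend(s.var(axis=0, ddof=1))
    return rhat(means, variances, n)


# Synthetic example: 28 walkers stuck in two well-separated modes at -5 and +5,
# each walker only exploring its own mode with unit variance.
rng = np.random.default_rng(0)
nsteps, nwalkers, ndim = 1000, 28, 2
centres = np.where(np.arange(nwalkers) % 2 == 0, -5.0, 5.0)
chain = centres[None, :, None] + rng.normal(size=(nsteps, nwalkers, ndim))

pooled = split_rhat_pooled(chain)
per_walker = split_rhat_per_walker(chain)
print("pooled split R-hat:    ", pooled)
print("per-walker split R-hat:", per_walker)
```

In this toy setup the pooled version reports values very close to 1 even though no walker ever leaves its mode, while the per-walker version reports values far above 1, which is the behaviour I would expect from a convergence diagnostic here.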

Thanks again, and let me know what your thoughts are. I'm happy to help implement these changes, too.