Open magland opened 1 year ago
I believe you will want to be able to report an ESS greater than the number of draws. This assumes the function people are interested in is the mean, but in that case anticorrelated draws (commonly produced by HMC) are actually better than independent samples
> I believe you will want to be able to report an ESS greater than the number of draws. This assumes the function people are interested in is the mean, but in that case anticorrelated draws (commonly produced by HMC) are actually better than independent samples
I guess that makes sense. But from what I understand, bayes-kit will essentially compute the area under the main (positive) lobe... and the only time you get IAT < 1 is when there is some technical circumstance where a bit of the negative dip after the main lobe is counted as part of the area. The present mcmc-monitor code doesn't ever include that negative piece. I would be surprised if the bayes-kit version is reliably picking up anticorrelated behavior, but perhaps I am misunderstanding something.
I can't really speak on the bayes_kit code, just the behavior I expect from tools like stansummary (which I believe is calculated with this code).
The bayes-kit code just implements the mathematical definition from the paper as described in the Stan reference manual.
The work in Stan's implementation is delegated to the function defined in stan/analyze/mcmc/compute_effective_sample_size.hpp
There's a comment starting at line 98 which says how this is adjusted for anti-autocorrelation. The monotonicity condition is implemented differently, though it may work out to the same thing.
P.S. We need to update the doc in our reference manual description of what's being computed to match the actual computation. What happened, I believe, is that we used to use the standard definition from Geyer's paper and then it was updated to deal with the kinds of anti-autocorrelated chains we see with Hamiltonian Monte Carlo.
I guess we should wait until the bayes_kit implementation is solidified, and then we can just copy it exactly. For now there's a working calculation for ESS that should be pretty close.
I sat down and did the algebra and I think the confusion is the offset on the pairwise positivity constraint. I originally had the wrong implementation there, using pairs (1, 2), (3, 4), .... Instead, what you want to do is take pairs (0, 1), (2, 3), ....
For example, if we have a simple time series model like an AR(1) process with autocorrelation rho in (-1, 1), then we know the overall autocorrelations are

```
ac[-2] = rho^2
ac[-1] = rho
ac[0]  = 1
ac[1]  = rho
ac[2]  = rho^2
```
Our estimator for integrated autocorrelation time is

```
IAT = ... + ac[-2] + ac[-1] + ac[0] + ac[1] + ac[2] + ...
```

but we have symmetry, ac[n] = ac[-n], so we can evaluate it with just the non-negative lags:

```
IAT = -1 + 2 * [ ac[0] + ac[1] + ... ]
```

where the -1 term is to avoid double counting the lag-zero autocorrelation ac[0] = 1.
Now, if we have rho = -0.9, we'll get something like this:

```
IAT = -1 + 2 * [ (1 + -0.9) + (0.81 + -0.729) + ... ] < 1
```
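Summing the geometric series gives a closed form for the AR(1) case: IAT = -1 + 2 / (1 - rho) = (1 + rho) / (1 - rho), which drops below 1 exactly when rho < 0. A quick numerical check (iat_ar1 is a throwaway illustration, not code from either repository):

```python
def iat_ar1(rho: float, max_lag: int = 10_000) -> float:
    # Truncated version of the series IAT = -1 + 2 * (ac[0] + ac[1] + ...)
    # for an AR(1) process, where ac[t] = rho ** t.
    # (Throwaway helper, not code from either repo.)
    return -1.0 + 2.0 * sum(rho ** t for t in range(max_lag))

rho = -0.9
print(iat_ar1(rho))           # ~0.0526, well below 1
print((1 + rho) / (1 - rho))  # closed form, same value
```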
Here's the behavior on the current branch, where sample_ar1 uses (lag 1) autocorrelation rho and generates a sample of size N.
```
>>> y = sample_ar1(rho = 0, N = 1000)
>>> len(y)
1000
>>> ess(y)
971.3854635513478
>>>
>>> y = sample_ar1(rho = 0.5, N = 1000)
>>> ess(y)
376.41893467549767
>>>
>>> y = sample_ar1(rho = -0.5, N = 1000)
>>> ess(y)
2999.0927300430913
```
So you can see in the last case that estimated ESS > N.
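For reference, an AR(1) sampler like the one used above can be sketched in a few lines. This sample_ar1 is a hypothetical stand-in for the branch's function (its signature and RNG choice are assumptions), included so the relationship between rho and the empirical lag-1 autocorrelation can be checked directly:

```python
import math
import random

def sample_ar1(rho: float, N: int, seed: int = 0) -> list[float]:
    # Hypothetical stand-in for the branch's sample_ar1 (signature assumed):
    # x[t] = rho * x[t-1] + sqrt(1 - rho^2) * eps[t] with eps ~ N(0, 1),
    # so the stationary variance is 1 and the lag-1 autocorrelation is rho.
    rng = random.Random(seed)
    scale = math.sqrt(1.0 - rho * rho)
    x = rng.gauss(0.0, 1.0)  # start from the stationary distribution
    out = []
    for _ in range(N):
        out.append(x)
        x = rho * x + scale * rng.gauss(0.0, 1.0)
    return out

# Sanity check: the empirical lag-1 autocorrelation should land near rho.
y = sample_ar1(rho=-0.5, N=20_000)
m = sum(y) / len(y)
num = sum((a - m) * (b - m) for a, b in zip(y, y[1:]))
den = sum((a - m) ** 2 for a in y)
print(num / den)  # close to -0.5
```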
Thanks @bob-carpenter. Okay, I will adjust our ess.ts to exactly match what's in your ess.py.
One minor note. Shouldn't the comment on first_neg_pair_start be adjusted? Right now it reads

> Return: index of first element whose sum with following element is negative, or the number of elements if there is no such element

But really it should be "Index of the first even-indexed element...
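For concreteness, here is a sketch of the behavior that corrected doc comment describes, with pairs taken at even offsets (0, 1), (2, 3), .... This is illustrative, not the actual mcmc-monitor or bayes-kit source:

```python
def first_neg_pair_start(acor: list[float]) -> int:
    # Sketch (not the mcmc-monitor/bayes-kit source): return the index of
    # the first EVEN-indexed element whose sum with the following element
    # is negative, or len(acor) if there is no such element.
    # Pairs are (0, 1), (2, 3), ...
    n = len(acor)
    i = 0
    while i + 1 < n:
        if acor[i] + acor[i + 1] < 0:
            return i
        i += 2
    return n

# For an anticorrelated AR(1) (rho = -0.9) every pair sum is positive,
# so the scan runs off the end:
print(first_neg_pair_start([(-0.9) ** t for t in range(10)]))  # 10
```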
@bob-carpenter There is still a typo in ess_imse where prevmin should be min_prev, which I guess will be resolved with PR https://github.com/flatironinstitute/bayes-kit/pull/16
But also, could you double-check the indexing there, because in this function (ess_imse) you are taking pairs (1, 2), (3, 4), etc., which seems inconsistent with first_neg_pair_start. (But maybe that's okay?)
That's a much more accurate way to doc, so I'll update. I'll also check code again. It should be checking pairs (0, 1), (1, 2), ....
The estimator without the monotonic-downward constraint on pair sums will have a bit more variance, but should still be OK.
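To make the monotone-constraint discussion concrete, here is a sketch of the initial-monotone-sequence idea applied to known autocorrelations: truncate at the first negative pair sum, then force the pair sums to be non-increasing via a running minimum. This illustrates the estimator's shape only; it is not bayes-kit's ess_imse code:

```python
def iat_imse(acor: list[float]) -> float:
    # Initial-monotone-sequence sketch (illustrative; not bayes-kit's
    # ess_imse). Pair sums over lags (0, 1), (2, 3), ..., truncated at
    # the first negative sum, then forced non-increasing.
    pair_sums = []
    for i in range(0, len(acor) - 1, 2):
        s = acor[i] + acor[i + 1]
        if s < 0:
            break
        pair_sums.append(s)
    total = 0.0
    running_min = float("inf")
    for s in pair_sums:
        running_min = min(running_min, s)  # monotone-downward constraint
        total += running_min
    return -1.0 + 2.0 * total  # -1 avoids double counting ac[0] = 1

# AR(1) with rho = -0.9: every pair sum 0.1 * 0.81**k stays positive,
# and the estimate lands near the true IAT of (1 - 0.9) / (1 + 0.9).
print(iat_imse([(-0.9) ** t for t in range(200)]))  # ~0.0526
```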
Sorry, Bob, I'm still confused. You've said:

> That's a much more accurate way to doc, so I'll update.

presumably referring to Jeremy's note that

> really it should be "Index of the first even-indexed element...

But then you go on to say:

> It should be checking pairs (0, 1), (1, 2),

which looks like it would be every pair, not just the even-indexed ones?
@jsoules My understanding is that "That's a much more accurate way to doc, so I'll update." refers to the first_neg_pair_start function, whereas the rest of Bob's paragraph refers to ess_imse.
Ack, I mean pairs (0, 1), (2, 3), .... Doc is hard, especially before coffee.
I noticed that the current ESS computation seems to ignore the Rhat part. For example, https://flatironinstitute.github.io/mcmc-monitor/?s=https://mcmc-monitor-proxy.herokuapp.com/s/c76d31a29d3d08e9a536&webrtc=0#/run/wu7p5aeu shows total ESS as 1140, but as Rhat is 6.95, the total ESS should be less than 4. It seems the current total ESS is just a sum of the individual chain ESS's, but it should take into account whether the chains are mixing (as detected by Rhat).
That's exactly what it's doing now. I need to implement the more sophisticated R-hat estimator. That will just be a plug-in change to all the visualizations and plots.
Once R-hat is implemented correctly, what should be the formula for the total ESS? Is it the sum of the individual ess's divided by R-hat, or something like that?
See Section 3.2 in https://doi.org/10.1214/20-BA1221. For combining the information from the individual chains and Rhat the key equation is (3.10). There is a correct Python implementation in ArviZ package
Thanks @avehtari. I'll read through this carefully (quite technical); I think it's important to get it right. I'm on the fence about whether or not to wait until this is implemented in bayes-kit. I'll chat with @jsoules and @bob-carpenter.
The current BayesKit only has the original R-hat algorithm and the ESS estimator that just sums across chains. In this situation the ESS estimator will be biased to the high side when R-hat >> 1.
There are three refined estimators of R-hat and different estimators of ESS:

1. Splitting chains. This is useful for one chain and it's cheap. I'd like to see a real model where it helps with multiple chains; the papers only have made-up examples.
2. Ranks. Reducing values to ranks helps with tails.
3. Using the R-hat variance estimator for ESS. This penalizes ESS if R-hat is high. This requires plugging multiple chains into ESS.
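As a reference point for the splitting idea in the first item, the classic split-R-hat computation (without the rank-normalization refinement) can be sketched as follows. This split_rhat is illustrative only, not the ArviZ or BayesKit implementation:

```python
def split_rhat(chains: list[list[float]]) -> float:
    # Classic split-R-hat sketch (no rank normalization; illustrative
    # only, not the ArviZ or BayesKit implementation). Each chain is
    # split in half so within-chain drift also shows up as
    # between-chain disagreement.
    halves = []
    for c in chains:
        n = len(c) // 2
        halves.append(c[:n])
        halves.append(c[n:2 * n])
    m, n = len(halves), len(halves[0])
    means = [sum(h) / n for h in halves]
    grand = sum(means) / m
    # Between-chain variance B and mean within-chain variance W.
    B = n / (m - 1) * sum((mu - grand) ** 2 for mu in means)
    W = sum(
        sum((x - mu) ** 2 for x in h) / (n - 1)
        for h, mu in zip(halves, means)
    ) / m
    var_plus = (n - 1) / n * W + B / n  # pooled variance estimate
    return (var_plus / W) ** 0.5

# Two chains stuck around different values inflate R-hat far above 1:
print(split_rhat([[0.0, 1.0] * 100, [10.0, 11.0] * 100]))
```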
I won't be able to get to this for a couple weeks. In terms of functionality, a new R-hat/ESS estimator should be pluggable at any point.
If you think it's critical to get a better R-hat/ESS estimator, you could try ArviZ. They have implemented a CmdStanPy interface to read in data, which is the trickiest part of ArviZ. I didn't want to include a dependency from BayesKit because ArviZ is very heavy in terms of both data structures and dependencies.
There are some unresolved questions here about exactly which formulae should be used (will need to discuss with @bob-carpenter).
Here is the bayes-kit implementation, which may have some issues:
https://github.com/flatironinstitute/bayes-kit/blob/main/bayes_kit/ess.py
Here is the current mcmc-monitor implementation, which may need to be adjusted:
https://github.com/flatironinstitute/mcmc-monitor/blob/459e8b6814c745bbf7681f3035b409177a540057/src/MCMCMonitorDataManager/stats/ess.ts
The critical functions in question are
first_neg_pair_start https://github.com/flatironinstitute/mcmc-monitor/blob/459e8b6814c745bbf7681f3035b409177a540057/src/MCMCMonitorDataManager/stats/ess.ts#L72-L82
and ess_imse https://github.com/flatironinstitute/bayes-kit/blob/22a3e9ff31f2268f47e13a737dc57c81a26ae917/bayes_kit/ess.py#L99-L135
The way mcmc-monitor does it now, the sigma_sq_hat (I think a.k.a. the IAT) is never going to be less than 1, which I believe should be a desirable property. But of course we'll want to be consistent with bayes-kit.
tagging: @jsoules @WardBrian