flatironinstitute / mcmc-monitor

Monitor MCMC runs in the browser

update ESS calculation code to be consistent with bayes-kit #5

Open magland opened 1 year ago

magland commented 1 year ago

There are some unresolved questions here about what exactly the formulae should be (will need to discuss with @bob-carpenter).

Here is the bayes-kit implementation, which may have some issues:

https://github.com/flatironinstitute/bayes-kit/blob/main/bayes_kit/ess.py

Here is the current mcmc-monitor implementation, which may need to be adjusted:

https://github.com/flatironinstitute/mcmc-monitor/blob/459e8b6814c745bbf7681f3035b409177a540057/src/MCMCMonitorDataManager/stats/ess.ts

The critical functions in question are

first_neg_pair_start https://github.com/flatironinstitute/mcmc-monitor/blob/459e8b6814c745bbf7681f3035b409177a540057/src/MCMCMonitorDataManager/stats/ess.ts#L72-L82

and ess_imse https://github.com/flatironinstitute/bayes-kit/blob/22a3e9ff31f2268f47e13a737dc57c81a26ae917/bayes_kit/ess.py#L99-L135

The way mcmc-monitor does it now, sigma_sq_hat (I think also known as the IAT) can never be less than 1, which I believe is a desirable property. But of course we'll want to be consistent with bayes-kit.

tagging: @jsoules @WardBrian

WardBrian commented 1 year ago

I believe you will want to be able to report an ESS greater than the number of draws. This assumes the function people are interested in is the mean, but in that case anticorrelated draws (commonly produced by HMC) are actually better than independent samples.
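
A small deterministic check of this point, assuming a stationary AR(1) chain with unit marginal variance (an illustrative model, not code from either repository; exact_ess_of_mean_ar1 is a hypothetical name). The variance of the mean of N draws is (1/N^2) * sum_{i,j} rho^|i-j|, and defining ESS by Var(mean) = 1/ESS gives ESS > N whenever rho < 0.

```python
def exact_ess_of_mean_ar1(rho: float, n: int) -> float:
    """ESS of the sample mean of a stationary AR(1) chain with unit marginal variance.

    Var(mean) = (1 / n^2) * sum_{i,j} rho^|i-j|; ESS is defined by Var(mean) = 1 / ESS,
    so ESS = n^2 / sum_{i,j} rho^|i-j|.
    """
    # Group the double sum by lag k = |i - j|: lag 0 has n terms, lag k >= 1 has 2 * (n - k).
    total = float(n)
    for k in range(1, n):
        total += 2 * (n - k) * rho**k
    return n * n / total


for rho in (0.5, 0.0, -0.5):
    print(rho, round(exact_ess_of_mean_ar1(rho, 1000), 1))
```

With N = 1000 this gives an ESS of roughly 334 for rho = 0.5, exactly 1000 for rho = 0, and roughly 2996 for rho = -0.5, i.e. about three times the number of draws.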

magland commented 1 year ago

I guess that makes sense. But from what I understand, bayes-kit essentially computes the area under the main (positive) lobe, and the only time you get IAT < 1 is when some technical circumstance causes a bit of the negative dip after the main lobe to be counted as part of the area. The present mcmc-monitor code never includes that negative piece. I would be surprised if the bayes-kit version reliably picks up anticorrelated behavior, but perhaps I am misunderstanding something.

WardBrian commented 1 year ago

I can't really speak to the bayes_kit code, just to the behavior I expect from tools like stansummary (which I believe is calculated with this code).

bob-carpenter commented 1 year ago

The bayes-kit code just implements the mathematical definition from the paper as described in the Stan reference manual.

The work in Stan's implementation is delegated to the function defined in stan/analyze/mcmc/compute_effective_sample_size.hpp.

There's a comment starting at line 98 that explains how this is adjusted for antiautocorrelation. The monotonicity condition is implemented differently, though it may work out to the same thing.

P.S. We need to update our reference manual's description of what's being computed to match the actual computation. What happened, I believe, is that we originally used the standard definition from Geyer's paper, and it was later updated to deal with the kind of antiautocorrelated chains we see with Hamiltonian Monte Carlo.

magland commented 1 year ago

I guess we should wait until the bayes_kit implementation is solidified, and then we can just copy it exactly. For now there's a working calculation for ESS that should be pretty close.

bob-carpenter commented 1 year ago

I sat down and did the algebra, and I think the confusion is the offset on the pairwise positivity constraint. I originally had the wrong implementation there, using pairs (1, 2), (3, 4), .... Instead, what you want to do is take pairs (0, 1), (2, 3), ....
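
To make the offset concrete, here is a minimal Python sketch of a pair-truncation helper (hypothetical code, not the bayes-kit or mcmc-monitor source). The offset parameter is an assumption added only so both pairings can be compared: offset=0 checks pairs (0, 1), (2, 3), ..., while offset=1 reproduces the old (1, 2), (3, 4), ... behavior. With offset=1, ac[0] = 1 is always kept before the first check, and for an AR(1)-like chain with negative rho the first checked pair rho + rho^2 is already negative, so the IAT comes out as exactly 1, consistent with the never-below-1 behavior noted at the top of this thread.

```python
from typing import Sequence


def first_neg_pair_start(ac: Sequence[float], offset: int = 0) -> int:
    """Start index of the first pair (i, i + 1), i = offset, offset + 2, ...,
    whose sum is negative, or len(ac) if there is no such pair.

    offset=0 checks (0, 1), (2, 3), ...; offset=1 checks (1, 2), (3, 4), ...
    """
    n = len(ac)
    for i in range(offset, n - 1, 2):
        if ac[i] + ac[i + 1] < 0:
            return i
    return n


def truncated_iat(ac: Sequence[float], offset: int = 0) -> float:
    """IAT = -1 + 2 * (ac[0] + ... + ac[k - 1]), truncated at the first negative pair."""
    k = first_neg_pair_start(ac, offset)
    return -1.0 + 2.0 * sum(ac[:k])


# Illustrative autocorrelations: (-0.5)^t for the first four lags, then noise.
ac = [1.0, -0.5, 0.25, -0.125, -0.05, 0.01]
print(truncated_iat(ac, offset=0))  # pairs (0, 1) and (2, 3) kept; (4, 5) is negative
print(truncated_iat(ac, offset=1))  # pair (1, 2) is already negative
```

This prints 0.25 and then 1.0; for comparison, the true IAT of an AR(1) process with rho = -0.5 is 1/3.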

For example, if we have a simple time-series model like an AR(1) process with autocorrelation rho in (-1, 1), then we know the overall autocorrelations are

ac[-2] = rho^2
ac[-1] = rho
ac[0] = 1
ac[1] = rho
ac[2] = rho^2

Our estimator for the integrated autocorrelation time (IAT) is

IAT = ... + ac[-2] + ac[-1] + ac[0] + ac[1] + ac[2] + ...

but by symmetry, ac[n] = ac[-n], so we evaluate it with just the non-negative lags

IAT = -1 + 2 * [ ac[0] + ac[1] + ... ]

where the -1 term is to avoid double counting the lag-zero autocorrelation ac[0] = 1.

Now, if we have rho = -0.9, we'll get something like this:

IAT = -1 + 2 * [ (1 + -0.9) + (.81 + -.729) + ... ]
    = (1 + rho) / (1 - rho) = 0.1 / 1.9 ≈ 0.053
    < 1
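
A quick numerical check of this derivation, assuming the exact AR(1) autocorrelations ac[n] = rho^|n| (the function names are hypothetical). Summing the series reproduces the closed form (1 + rho) / (1 - rho). It also shows why the even-pair truncation never triggers here: each pair sum ac[2k] + ac[2k+1] = rho^(2k) * (1 + rho) is positive for any rho in (-1, 1), so the whole series is kept, and for negative rho the IAT falls below 1.

```python
def ar1_iat(rho: float, max_lag: int = 500) -> float:
    """IAT = -1 + 2 * (ac[0] + ac[1] + ... + ac[max_lag]) with ac[n] = rho^n."""
    return -1.0 + 2.0 * sum(rho**k for k in range(max_lag + 1))


def ar1_iat_closed_form(rho: float) -> float:
    """Geometric series: -1 + 2 / (1 - rho) = (1 + rho) / (1 - rho)."""
    return (1.0 + rho) / (1.0 - rho)


for rho in (0.5, 0.0, -0.5, -0.9):
    print(rho, ar1_iat(rho), ar1_iat_closed_form(rho))
```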

Here's the behavior on the current branch, where sample_ar1 generates a sample of size N from an AR(1) process with lag-1 autocorrelation rho.

>>> y = sample_ar1(rho = 0, N = 1000)
>>> len(y)
1000
>>> ess(y)
971.3854635513478
>>> 
>>> y = sample_ar1(rho = 0.5, N = 1000)
>>> ess(y)
376.41893467549767
>>> 
>>> y = sample_ar1(rho = -0.5, N = 1000)
>>> ess(y)
2999.0927300430913

So you can see in the last case that the estimated ESS > N.
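
For readers who want to reproduce this kind of experiment, here is one way sample_ar1 could be written. Its real definition isn't shown in the thread, so this version (stationary start, Gaussian innovations with variance 1 - rho^2, an optional seed argument) is an assumption. For an AR(1) process the asymptotic ESS of the mean is N / IAT = N (1 - rho) / (1 + rho), i.e. about 333, 1000 and 3000 for rho = 0.5, 0 and -0.5 with N = 1000, which gives a reference point for the estimates printed above.

```python
import math
import random
from typing import List, Optional


def sample_ar1(rho: float, N: int, seed: Optional[int] = None) -> List[float]:
    """Hypothetical AR(1) sampler: y[t] = rho * y[t-1] + e[t], stationary with unit variance."""
    if not -1.0 < rho < 1.0:
        raise ValueError("rho must lie in (-1, 1)")
    rng = random.Random(seed)
    innovation_sd = math.sqrt(1.0 - rho * rho)  # keeps Var(y[t]) = 1
    y = [rng.gauss(0.0, 1.0)]  # draw y[0] from the stationary N(0, 1) distribution
    for _ in range(N - 1):
        y.append(rho * y[-1] + rng.gauss(0.0, innovation_sd))
    return y


def lag1_autocorrelation(y: List[float]) -> float:
    """Sample lag-1 autocorrelation (autocovariances divided by the same N)."""
    mean = sum(y) / len(y)
    c = [v - mean for v in y]
    return sum(a * b for a, b in zip(c, c[1:])) / sum(v * v for v in c)


def ar1_theoretical_ess(rho: float, N: int) -> float:
    """Asymptotic ESS of the mean for AR(1): N / IAT with IAT = (1 + rho) / (1 - rho)."""
    return N * (1.0 - rho) / (1.0 + rho)


for rho in (0.0, 0.5, -0.5):
    y = sample_ar1(rho, N=1000, seed=1)
    print(rho, len(y), round(lag1_autocorrelation(y), 3), ar1_theoretical_ess(rho, 1000))
```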

magland commented 1 year ago

Thanks @bob-carpenter. Okay, I will adjust our ess.ts to exactly match what's in your ess.py.

One minor note. Shouldn't the comment on first_neg_pair_start be adjusted?

https://github.com/flatironinstitute/bayes-kit/blob/40a34129a6313056bc4fedc42606578cee6dff77/bayes_kit/ess.py#L65-L67

Right now it reads

Return: index of first element whose sum with following element is negative, or the number of elements if there is no such element

But really it should be "Index of the first even-indexed element..."

magland commented 1 year ago

@bob-carpenter There is still a typo in ess_imse where prevmin should be min_prev, which I guess will be resolved with PR https://github.com/flatironinstitute/bayes-kit/pull/16

But also, could you double-check the indexing there? In this function (ess_imse) you are taking pairs (1, 2), (3, 4), etc., which seems inconsistent with first_neg_pair_start (but maybe that's okay?).

bob-carpenter commented 1 year ago

That's a much more accurate way to doc, so I'll update. I'll also check code again. It should be checking pairs (0, 1), (1, 2), ....

The estimator without the monotonic-downward constraint on pair sums will have a bit more variance, but should still be OK.
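
A minimal sketch of the difference being described, written from Geyer's initial positive and initial monotone sequence definitions rather than copied from bayes-kit's ess_imse (the name geyer_iat and the monotone flag are hypothetical). Pair sums use (0, 1), (2, 3), ...; truncation happens at the first negative pair sum, and with monotone=True each pair sum is additionally capped at the previous one so the sequence never increases.

```python
from typing import List, Sequence


def geyer_iat(ac: Sequence[float], monotone: bool = True) -> float:
    """IAT from pair sums P_k = ac[2k] + ac[2k + 1], using pairs (0, 1), (2, 3), ..."""
    pair_sums: List[float] = []
    for i in range(0, len(ac) - 1, 2):
        p = ac[i] + ac[i + 1]
        if p < 0:  # initial positive sequence: stop at the first negative pair sum
            break
        if monotone and pair_sums and p > pair_sums[-1]:
            p = pair_sums[-1]  # initial monotone sequence: never let a pair sum increase
        pair_sums.append(p)
    # Without capping, the sum of pair sums equals ac[0] + ... + ac[2K - 1].
    return -1.0 + 2.0 * sum(pair_sums)


# Illustrative autocorrelations with a noisy bump in the third pair.
ac = [1.0, -0.5, 0.25, 0.0, 0.125, 0.25, -0.5, 0.0]
print(geyer_iat(ac, monotone=True))   # third pair sum 0.375 is capped at 0.25
print(geyer_iat(ac, monotone=False))  # third pair sum kept as 0.375
```

This prints 1.0 and then 1.25. Capping can only lower pair sums, smoothing out noisy bumps in the sample autocorrelations; dropping it gives the slightly higher-variance estimator mentioned above.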

jsoules commented 1 year ago

Sorry, Bob, I'm still confused. You've said:

That's a much more accurate way to doc, so I'll update.

presumably referring to Jeremy's note that

really it should be "Index of the first even-indexed element...

But then you go on to say:

It should be checking pairs (0, 1), (1, 2),

which looks like it would be every pair, not just the even-indexed ones?

magland commented 1 year ago

@jsoules My understanding is that "That's a much more accurate way to doc, so I'll update." refers to the first_neg_pair_start function, whereas the rest of Bob's paragraph refers to ess_imse.

bob-carpenter commented 1 year ago

Ack, I mean pairs (0, 1), (2, 3), .... Doc is hard, especially before coffee.

avehtari commented 1 year ago

I noticed that the current ESS computation seems to ignore the Rhat part. For example, https://flatironinstitute.github.io/mcmc-monitor/?s=https://mcmc-monitor-proxy.herokuapp.com/s/c76d31a29d3d08e9a536&webrtc=0#/run/wu7p5aeu shows a total ESS of 1140, but as Rhat is 6.95, the total ESS should be less than 4. It seems the current total ESS is just the sum of the individual chain ESSs, but it should take into account whether the chains are mixing (as detected by Rhat).

bob-carpenter commented 1 year ago

That's exactly what it's doing now. I need to implement the more sophisticated R-hat estimator. That will just be a plug-in change to all the visualizations and plots.

magland commented 1 year ago

Once R-hat is implemented correctly, what should the formula for the total ESS be? Is it the sum of the individual ESSs divided by R-hat, or something like that?

avehtari commented 1 year ago

See Section 3.2 in https://doi.org/10.1214/20-BA1221. For combining the information from the individual chains and Rhat, the key equation is (3.10). There is a correct Python implementation in the ArviZ package.
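
A rough pure-Python sketch of the multi-chain idea, following our reading of Section 3.2 of the linked paper rather than the ArviZ or Stan code. Per-chain autocorrelations rho_{t,m} are combined with the within-chain variance W and the between-chain variance B as rho_t = 1 - (W - mean_m(s_m^2 * rho_{t,m})) / var_plus, with var_plus = (N - 1)/N * W + B/N, and the total ESS is M * N / tau, with tau from Geyer's truncation over pairs (0, 1), (2, 3), .... It skips chain splitting and rank normalization (which the paper recommends) and the antiautocorrelation adjustment mentioned above for Stan's implementation, uses direct O(N^2) autocovariances, and multichain_ess is a hypothetical name; the test chains come from Python's random module purely for illustration. Because var_plus includes B, chains stuck in different places push every rho_t toward 1 and the total ESS collapses, which is the behavior expected for the Rhat = 6.95 example.

```python
import random
from typing import Sequence


def multichain_ess(chains: Sequence[Sequence[float]]) -> float:
    """Hypothetical multi-chain ESS sketch; needs M >= 2 chains of equal length N >= 2."""
    m, n = len(chains), len(chains[0])
    means = [sum(c) / n for c in chains]
    centered = [[x - mu for x in c] for c, mu in zip(chains, means)]
    acov0 = [sum(x * x for x in c) / n for c in centered]  # lag-0 autocovariance (divide by N)
    s2 = [a * n / (n - 1) for a in acov0]  # within-chain variances s_m^2
    w = sum(s2) / m  # W: mean within-chain variance
    grand_mean = sum(means) / m
    b_over_n = sum((mu - grand_mean) ** 2 for mu in means) / (m - 1)  # B / N
    var_plus = (n - 1) / n * w + b_over_n

    def rho_hat(t: int) -> float:
        # Combine per-chain autocorrelations rho_{t,m} = acov_t / acov_0 across chains.
        weighted = 0.0
        for c, a0, s2m in zip(centered, acov0, s2):
            acov_t = sum(c[i] * c[i + t] for i in range(n - t)) / n
            weighted += s2m * acov_t / a0
        return 1.0 - (w - weighted / m) / var_plus

    # Geyer's initial positive sequence over pairs (0, 1), (2, 3), ...
    tau = -1.0
    for t in range(0, n - 1, 2):
        pair = rho_hat(t) + rho_hat(t + 1)
        if pair < 0:
            break
        tau += 2.0 * pair
    if tau <= 0:
        raise ValueError("non-positive IAT estimate; a real implementation needs a floor here")
    return m * n / tau


rng = random.Random(0)
mixed = [[rng.gauss(0.0, 1.0) for _ in range(200)] for _ in range(4)]
stuck = [[rng.gauss(10.0 * k, 1.0) for _ in range(200)] for k in range(4)]
print(multichain_ess(mixed))  # independent, well-mixed chains: ESS on the order of M * N = 800
print(multichain_ess(stuck))  # chains centred at 0, 10, 20, 30: ESS collapses to a handful
```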

magland commented 1 year ago

Thanks @avehtari. I'll read through this carefully (quite technical); I think it's important to get it right. I'm on the fence about whether or not to wait until this is implemented in bayes-kit. I'll chat with @jsoules and @bob-carpenter.

bob-carpenter commented 1 year ago

The current BayesKit only has the original R-hat algorithm and the ESS estimator that just sums across chains. In this situation the ESS estimator will be biased to the high side when R-hat >> 1.

There are three refined estimators of R-hat and different estimators of ESS:

I won't be able to get to this for a couple weeks. In terms of functionality, a new R-hat/ESS estimator should be pluggable at any point.

If you think it's critical to get a better R-hat/ESS estimator, you could try ArviZ. They have implemented a CmdStanPy interface to read in data, which is the trickiest part of ArviZ. I didn't want BayesKit to depend on ArviZ because ArviZ is very heavy in terms of both data structures and dependencies.
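
For reference, a sketch of the original (non-split, non-rank-normalized) R-hat mentioned above, written from the textbook between/within-variance definition rather than from BayesKit's code, so details such as variance denominators may differ; rhat_classic is a hypothetical name. The key quantity is var_plus = (N - 1)/N * W + B/N, and R-hat = sqrt(var_plus / W) grows large exactly when the chains disagree, which is when a per-chain ESS summed across chains is biased high.

```python
import math
from typing import Sequence


def rhat_classic(chains: Sequence[Sequence[float]]) -> float:
    """Classic potential scale reduction factor; needs M >= 2 chains of equal length N >= 2."""
    m, n = len(chains), len(chains[0])
    means = [sum(c) / n for c in chains]
    grand_mean = sum(means) / m
    b = n / (m - 1) * sum((mu - grand_mean) ** 2 for mu in means)  # between-chain variance B
    w = sum(sum((x - mu) ** 2 for x in c) / (n - 1) for c, mu in zip(chains, means)) / m  # W
    var_plus = (n - 1) / n * w + b / n
    return math.sqrt(var_plus / w)


print(rhat_classic([[1, 2, 3, 4], [1, 2, 3, 4]]))      # identical chains: sqrt(0.75) ≈ 0.866
print(rhat_classic([[1, 2, 3, 4], [11, 12, 13, 14]]))  # separated chains: sqrt(30.75) ≈ 5.55
```

With chains this short, the (N - 1)/N factor pulls R-hat slightly below 1 even for identical chains; for long, well-mixed chains it approaches 1.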