TuringLang / MCMCChains.jl

Types and utility functions for summarizing Markov chain Monte Carlo simulations
https://turinglang.org/MCMCChains.jl/

Calculating acceptance rate #409

Open scheidan opened 1 year ago

scheidan commented 1 year ago

I often miss a function that computes the acceptance rate of a chain.

I'm happy to put a PR together if you feel that would be a useful addition.

Some points we should think about:

cpfiffer commented 1 year ago

I suppose I could be okay with trying this out. Could you try benchmarking these on some big-ass matrices? Like, show us compute times for countmap and size(unique(...)) for different matrix sizes.

I will say, though, that in general the real way to do this is to calculate acceptance rates from an isaccept flag (or whatever the sampler records), which IIRC doesn't happen much for a lot of the samplers that talk to MCMCChains. The thing you proposed is probably slow, but maybe a good first approximation of an acceptance rate.
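When the sampler does record an accept flag, the rate is simply the fraction of accepted proposals. The sketch below uses a hypothetical toy random-walk Metropolis sampler (`rw_metropolis` is not part of MCMCChains or AdvancedMH) on a standard normal target to compare that flag-based rate with the change-based approximation. With continuous proposals an accepted move almost surely changes the state, so under these assumptions the two numbers should coincide.

```julia
using Random

# Hypothetical toy sampler (not from MCMCChains or AdvancedMH): random-walk
# Metropolis for a standard normal target that records an `isaccept` flag.
function rw_metropolis(rng::AbstractRNG, n::Int; step::Float64 = 2.0)
    logp(z) = -z^2 / 2                  # unnormalized log density of N(0, 1)
    x = 0.0
    draws = Vector{Float64}(undef, n)
    isaccept = Vector{Bool}(undef, n)
    for t in 1:n
        y = x + step * randn(rng)       # symmetric Gaussian proposal
        accepted = log(rand(rng)) < logp(y) - logp(x)
        if accepted
            x = y
        end
        draws[t] = x
        isaccept[t] = accepted
    end
    return draws, isaccept
end

# Change-based approximation for a single parameter: fraction of
# transitions t-1 -> t in which the value changed.
change_rate(x::AbstractVector) = count(i -> x[i] != x[i-1], 2:length(x)) / (length(x) - 1)

draws, isaccept = rw_metropolis(MersenneTwister(42), 10_000)
# Flag-based rate over the same n - 1 transitions the approximation looks at.
exact = count(@view isaccept[2:end]) / (length(isaccept) - 1)
approx = change_rate(draws)
```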

scheidan commented 1 year ago

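Thanks! Thinking about it, we do not need (or want) to use unique or countmap. Besides being much slower, it would also be wrong whenever we have discrete parameters, because a discrete chain can jump back to a state it has already visited and such revisits would not be counted.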
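A small illustration of why counting distinct states breaks down for discrete parameters. `unique_rate` is a hypothetical unique-rows estimator (the exact formula of the original proposal is not shown in this thread), and `change_rate` counts iterations in which at least one parameter changed. On a continuous chain without revisits the two agree; on a discrete chain that keeps returning to earlier states, the unique-rows count collapses the revisits.

```julia
# Hypothetical unique-rows estimator: (number of distinct states - 1) / transitions.
unique_rate(x::AbstractMatrix) = (size(unique(x; dims = 1), 1) - 1) / (size(x, 1) - 1)

# Change-based estimator: a transition is a jump if any parameter changed.
function change_rate(x::AbstractMatrix)
    n_jumps = count(i -> any(j -> x[i, j] != x[i-1, j], axes(x, 2)), 2:size(x, 1))
    return n_jumps / (size(x, 1) - 1)
end

# Iterations × 1 parameter. Continuous values, two rejections, no revisits:
x_cont = reshape([0.1, 0.1, 0.7, 0.7, -0.3, 0.2], :, 1)
# Discrete values: every transition is a jump, but only two states are visited.
x_disc = reshape([0, 1, 0, 1, 0, 1], :, 1)

unique_rate(x_cont), change_rate(x_cont)   # (0.6, 0.6)
unique_rate(x_disc), change_rate(x_disc)   # (0.2, 1.0)
```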

The approach below should be correct: for every iteration it simply checks whether at least one parameter has changed, i.e. whether a jump has happened. It should also be reasonably efficient and does not allocate apart from the result vector (we could use threading, but I'm not sure it is worth it here).

If that looks reasonable, I will work on a PR in the next days.

```julia
using MCMCChains, BenchmarkTools

# acceptance rate of a single chain (iterations × parameters):
# a transition counts as a jump if at least one parameter changed
function _acceptance_rate(x)
    n_jumps = 0
    for i in axes(x, 1)[2:end]
        for j in axes(x, 2)
            if x[i,j] != x[i-1,j]
                n_jumps += 1
                break   # one changed parameter is enough
            end
        end
    end
    n_jumps / (size(x, 1)-1)
end

# one acceptance rate per chain
function acceptance_rate(chn::Chains)
    nchains = size(chn, 3)
    [_acceptance_rate(@view chn.value[:,:,i]) for i in 1:nchains]
end

## -------------
## test

n = 10000                       # chain length
d = 200                         # dimension
k = 100                         # number of chains

# first half: random draws (every step is a jump); second half: constant (no jumps)
X = cat(randn(n÷2, d, k), ones(n÷2, d, k), dims=1);
chn = Chains(X);

acceptance_rate(chn)             # every entry 5000/9999 ≈ 0.5
@btime acceptance_rate($(chn))   # ~ 200ms, 1 allocation (900 bytes)

## just for comparison
@btime ess($chn)   # ~ 10 sec
```
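On the threading question: chains are independent, so one option is to compute each chain's rate on its own task. Below is a minimal sketch with `Threads.@threads`, written against a plain `iterations × parameters × chains` array so it runs without MCMCChains. `acceptance_rate_threaded` is a hypothetical name, and whether it pays off depends on the number of chains and on the thread count Julia was started with (e.g. `julia -t 4`).

```julia
# Change-based acceptance rate of one chain (iterations × parameters),
# same logic as `_acceptance_rate` above; assumes 1-based indexing.
function _acceptance_rate(x::AbstractMatrix)
    n_jumps = 0
    for i in 2:size(x, 1)
        for j in axes(x, 2)
            if x[i, j] != x[i-1, j]
                n_jumps += 1
                break                      # one changed parameter is enough
            end
        end
    end
    return n_jumps / (size(x, 1) - 1)
end

# Hypothetical threaded variant: each task writes only to its own slot of
# `rates`, so no locking is needed.
function acceptance_rate_threaded(data::AbstractArray{<:Real,3})
    rates = Vector{Float64}(undef, size(data, 3))
    Threads.@threads for c in 1:size(data, 3)
        rates[c] = _acceptance_rate(view(data, :, :, c))
    end
    return rates
end

X = cat(randn(500, 3, 4), ones(500, 3, 4); dims = 1)
acceptance_rate_threaded(X)   # four entries, each 500/999
```

For a `Chains` object, the underlying array would presumably be `chn.value.data`; that relies on `value` being an `AxisArray`, which is an assumption about MCMCChains internals.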
scheidan commented 1 year ago

Would a PR to MCMCDiagnosticTools.jl make more sense?

devmotion commented 1 year ago

To me this heuristic seems to be, well, only a heuristic and a bit too brittle to be added to MCMCDiagnosticTools or MCMCChains. Even checking if all parameters are the same does not guarantee that a proposal was rejected, in particular not when working with distributions with finite discrete support. Also, similar to what @cpfiffer said above, I think acceptance rates are a rather algorithm-specific thing and not a concept that can be applied to an arbitrary Markov chain (e.g., elliptical slice sampling does not reject or accept any samples, it just returns a sequence of samples). So I think this should be addressed in e.g. AdvancedMH (https://github.com/TuringLang/AdvancedMH.jl/issues/40, https://github.com/TuringLang/AdvancedMH.jl/issues/38).
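A toy example of the failure mode described here: with finite discrete support a proposal can land on the current state and be accepted, which the change-based heuristic cannot tell apart from a rejection. `discrete_metropolis` is a hypothetical sampler (independent uniform proposals on `1:K`, with arbitrarily chosen target probabilities `p`) that records the true accept flags, so the two rates can be compared directly.

```julia
using Random

# Hypothetical toy sampler: Metropolis on the states 1:K with target
# probabilities `p` and an independent uniform proposal that can propose
# the current state itself.
function discrete_metropolis(rng::AbstractRNG, n::Int, p::Vector{Float64})
    K = length(p)
    x = 1
    draws = Vector{Int}(undef, n)
    isaccept = Vector{Bool}(undef, n)
    for t in 1:n
        y = rand(rng, 1:K)                  # uniform proposal, may equal x
        accepted = rand(rng) < p[y] / p[x]  # Metropolis ratio (symmetric proposal)
        if accepted
            x = y
        end
        draws[t] = x
        isaccept[t] = accepted
    end
    return draws, isaccept
end

draws, isaccept = discrete_metropolis(MersenneTwister(7), 10_000, [0.2, 0.3, 0.5])
n = length(draws)
exact = count(@view isaccept[2:end]) / (n - 1)
heuristic = count(i -> draws[i] != draws[i-1], 2:n) / (n - 1)
# heuristic < exact: accepted proposals with y == x look like rejections
```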

scheidan commented 1 year ago

Fair enough, but then, many things are a bit heuristic when working with MCMC :)

I agree the clean solution is for the sampler to return the acceptance rate. However, MCMCChains turns out to be very useful for comparing results from various samplers outside the TuringLang domain as well: converting their results (all in slightly different formats) into Chains gives uniform plotting and diagnostics. In this context a heuristic acceptance rate (maybe with a warning) would be much better than none.
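One way to read "a heuristic with a warning" as code: a hypothetical `heuristic_acceptance_rate` that works on any `iterations × parameters × chains` array (for example one converted from another sampler's output), returns one change-based rate per chain, and warns when some parameter looks discrete, since that is where the estimate is known to be biased low. Treating "all draws are integer-valued" as "discrete" is itself an assumption of this sketch.

```julia
# Hypothetical helper: change-based acceptance rate per chain, with a warning
# when the heuristic is known to be unreliable (integer-valued parameters).
function heuristic_acceptance_rate(data::AbstractArray{<:Real,3})
    n = size(data, 1)
    n > 1 || throw(ArgumentError("need at least two iterations"))
    # A parameter is treated as discrete if all of its draws are integers.
    discrete = [all(isinteger, view(data, :, j, :)) for j in axes(data, 2)]
    if any(discrete)
        msg = "Parameters $(findall(discrete)) look discrete; the heuristic acceptance rate may be too low."
        @warn msg
    end
    return map(axes(data, 3)) do c
        chain = view(data, :, :, c)
        # fraction of transitions in which at least one parameter changed
        count(i -> any(j -> chain[i, j] != chain[i-1, j], axes(chain, 2)), 2:n) / (n - 1)
    end
end

X = cat(randn(100, 2, 3), ones(100, 2, 3); dims = 1)
heuristic_acceptance_rate(X)   # no warning; each entry 100/199

D = reshape(Float64[1, 2, 2, 3, 3, 3], 6, 1, 1)
heuristic_acceptance_rate(D)   # warns; returns [0.4]
```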

Of course it is up to you to define the scope of MCMCChains. I just want to point out that it may be useful in a wider range of applications than you thought.