TuringLang / AbstractMCMC.jl

Abstract types and interfaces for Markov chain Monte Carlo methods
https://turinglang.org/AbstractMCMC.jl
MIT License

`sample` equivalent but including states #84

Open · torfjelde opened this issue 3 years ago

torfjelde commented 3 years ago

Sometimes I want both the samples and the states rather than just the samples. Of course, this can be achieved by using the iterator interface explicitly, or a callback, but it's a bit inconvenient to have to write this every time.

Would it make sense to introduce a sample_with_states method, or simply a keyword argument to sample, e.g. include_states::Bool, specifying whether the states should also be included in the return value? (I'm more in favour of having the kwarg make us call save!! with the tuple (sample, state), rather than having it change the return value so that we instead return samples, states at the end.)

Thoughts?
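For comparison, this is roughly what writing the loop by hand looks like. It is a minimal, self-contained sketch: `ToyModel`, `ToySampler`, `toy_step` and `collect_samples_and_states` are hypothetical stand-ins that only mimic the convention of AbstractMCMC's `step`, which returns a `(sample, state)` pair, so the example runs without the package installed.

```julia
using Random

# Hypothetical toy model and sampler (a Gaussian random walk); these are
# illustrative stand-ins, not types from AbstractMCMC.
struct ToyModel end
struct ToySampler
    σ::Float64
end

# First step: no previous state. Like AbstractMCMC's `step`, it returns a
# `(sample, state)` pair; here the state is a NamedTuple.
function toy_step(rng::AbstractRNG, ::ToyModel, spl::ToySampler)
    x = randn(rng)
    return x, (x = x, n = 1)
end

# Later steps: propose a new point from the previous state.
function toy_step(rng::AbstractRNG, ::ToyModel, spl::ToySampler, state)
    x = state.x + spl.σ * randn(rng)
    return x, (x = x, n = state.n + 1)
end

# Explicit loop that keeps both the samples and the states.
function collect_samples_and_states(rng, model, spl, N)
    sample, state = toy_step(rng, model, spl)
    samples = [sample]
    states = [state]
    for _ in 2:N
        sample, state = toy_step(rng, model, spl, state)
        push!(samples, sample)
        push!(states, state)
    end
    return samples, states
end

samples, states = collect_samples_and_states(MersenneTwister(1), ToyModel(), ToySampler(0.5), 10)
println(length(samples), " samples, ", length(states), " states")
```

This boilerplate is what a keyword argument or dedicated method would spare users from rewriting.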

cpfiffer commented 3 years ago

I'm in favor of adding the keyword argument and passing it to save!!, so instead of calling

samples = save!!(samples, sample, i, model, sampler; kwargs...)

we would call

samples = save!!(samples, (sample, state), i, model, sampler; kwargs...)

as you suggested. Our default version of save!! would then just throw out state if it was not requested. Though, now that I'm looking at this, it looks like there might be some type instability: the type of the container we build would depend on the run-time value of include_states. Perhaps we could make include_states a type we dispatch on?
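A small sketch of where the instability comes from, assuming `include_states` is a plain `Bool`. `make_container` and `save_toy!!` are hypothetical helpers, not AbstractMCMC functions: the save step always receives `(sample, state)` and drops the state unless the container can hold it, but the container's type is chosen by a run-time branch, so inference can only return a `Union`.

```julia
# Hypothetical: with a plain Bool, the element type of the container
# depends on a value that is only known at run time.
function make_container(include_states::Bool)
    return include_states ? Tuple{Float64,Int}[] : Float64[]
end

# Hypothetical `save!!`-like helpers: always given `(sample, state)`,
# they keep the state only if the container was built to store it.
save_toy!!(samples::Vector{Float64}, (sample, state)) = push!(samples, sample)
save_toy!!(samples::Vector{Tuple{Float64,Int}}, (sample, state)) = push!(samples, (sample, state))

# Inference cannot pick one container type, only a Union of both.
rt = only(Base.return_types(make_container, (Bool,)))
println(rt)
```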

At the very least, we shouldn't add a whole separate sample_with_states method, to keep maintenance down: the changes are fairly minor, and it's not worth copying all the sample code around.

torfjelde commented 3 years ago

Hmm, this is actually a bit more annoying than I originally thought :confused:

As you said, it'll introduce type instabilities unless we make it a Val-typed argument or something. We could of course do that, but it's more annoying than just a kwarg.
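A minimal sketch of the Val-typed variant, using hypothetical names (`make_container`, `store!`, `run_toy`) rather than anything in AbstractMCMC: the flag is lifted into the type domain with `Val`, so dispatch picks the container type at compile time and the loop infers to a single concrete return type.

```julia
# Hypothetical: encode the flag in the type, so dispatch (not a run-time
# branch) decides what the container stores.
make_container(::Val{true}) = Tuple{Float64,Int}[]   # store (sample, state)
make_container(::Val{false}) = Float64[]             # store samples only

store!(samples, sample, state, ::Val{true}) = push!(samples, (sample, state))
store!(samples, sample, state, ::Val{false}) = push!(samples, sample)

# Toy loop; `(Float64(i), i)` stands in for the `(sample, state)` of a real step.
function run_toy(n::Int, include_states::Val)
    samples = make_container(include_states)
    for i in 1:n
        sample, state = Float64(i), i
        store!(samples, sample, state, include_states)
    end
    return samples
end

println(only(Base.return_types(run_toy, (Int, Val{true}))))
```

The cost is on the caller's side: they would write `include_states=Val(true)` instead of `include_states=true`.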

torfjelde commented 3 years ago

Maybe we should just provide a callback? I.e.

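struct StateHistoryCallback{A}
    states::A   # container the states are pushed into
end
# Default: collect into an untyped vector.
StateHistoryCallback() = StateHistoryCallback(Any[])

# Called after every iteration; records the current state.
function (cb::StateHistoryCallback)(rng, model, sampler, sample, state, i; kwargs...)
    push!(cb.states, state)
    return nothing
end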

so users can do

state_history = []   # filled in place by the callback
sample(..., callback=StateHistoryCallback(state_history))
state_history        # now holds one state per iteration

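A self-contained way to try the callback without AbstractMCMC: `toy_sample` below is a hypothetical stand-in for `sample` that calls `callback(rng, model, sampler, sample, state, i)` after each step, with the argument order taken from the signature above. Passing a typed vector such as `Float64[]` keeps the history concretely typed, whereas the zero-argument constructor falls back to `Any[]`.

```julia
using Random

# Callback from the comment above: records every state it is handed.
struct StateHistoryCallback{A}
    states::A
end
StateHistoryCallback() = StateHistoryCallback(Any[])

function (cb::StateHistoryCallback)(rng, model, sampler, sample, state, i; kwargs...)
    push!(cb.states, state)
    return nothing
end

# Hypothetical default: a callback that does nothing.
noop_callback(args...; kwargs...) = nothing

# Hypothetical stand-in for `sample`: a random walk whose state is the
# current position; the callback is invoked after every iteration.
function toy_sample(rng, model, sampler, N; callback = noop_callback)
    state = 0.0
    samples = Float64[]
    for i in 1:N
        state += randn(rng)
        sample = state
        push!(samples, sample)
        callback(rng, model, sampler, sample, state, i)
    end
    return samples
end

state_history = Float64[]
samples = toy_sample(MersenneTwister(42), nothing, nothing, 5;
                     callback = StateHistoryCallback(state_history))
println(length(state_history), " states recorded")
```

Since the callback only stores references, a sampler that mutates its state in place would probably need `push!(cb.states, deepcopy(state))` instead.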
cpfiffer commented 3 years ago

I dunno, that seems super hacky to me. I think it's valuable to provide the states, but it's worth investing in giving this its own code path (whatever that looks like).