TuringLang / AbstractMCMC.jl

Abstract types and interfaces for Markov chain Monte Carlo methods
https://turinglang.org/AbstractMCMC.jl
MIT License

[RFC] Sampler states #31

devmotion closed this issue 4 years ago

devmotion commented 4 years ago

I think the current approach of dealing with sampler states, e.g., in Turing, is flawed. Currently, one defines an initial state when initializing the sampler. This is problematic because the concrete types of the states are often not known during initialization. That in turn requires setting up dummy states and states with abstractly typed fields, which can hurt performance.
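A minimal sketch of the typing problem, using hypothetical types rather than Turing's actual samplers: a sampler that owns its state must be constructed before the state's type is known, so the field ends up typed `Any` and is filled with a dummy value. A state produced later by a hypothetical `initialstate` function, once the model and starting point are known, can instead have a concrete type.

```julia
# Old style (hypothetical): the state lives inside the sampler and must exist
# before sampling starts, so its field is abstractly typed.
mutable struct StatefulSampler
    stepsize::Float64
    state::Any  # concrete type unknown at construction time
end

# Constructing the sampler requires a dummy state.
StatefulSampler(stepsize) = StatefulSampler(stepsize, nothing)

# Proposed style (hypothetical): the sampler only holds its settings ...
struct StatelessSampler
    stepsize::Float64
end

# ... and the state is created once the starting point is known,
# so its type parameter is fixed by the actual data.
struct State{T}
    x::Vector{T}
    logp::T
end

initialstate(logdensity, x0::Vector{T}) where {T} = State(x0, logdensity(x0))

# Usage: the state's type is inferred as `State{Float64}`.
logdensity(x) = -sum(abs2, x) / 2
s = initialstate(logdensity, [0.5, -1.0])
```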

With the current API of AbstractMCMC, it is possible to avoid this issue by making the sampler state part of the transition generated by step! and then implementing transitions_init and transitions_save! such that only the part of the transition without the sampler state is actually saved (as done for the Gibbs sampler in Turing and proposed for the DynamicNUTS sampler in https://github.com/TuringLang/Turing.jl/pull/1186). Moreover, since only the samples and statistics are stored, it is impossible to resume sampling: the state of the sampler is no longer known.
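The workaround can be sketched without AbstractMCMC; `TransitionWithState` and `save_without_state!` below are hypothetical stand-ins for a transition type and a transitions_save! implementation, not the actual Turing code. The sketch also shows why resuming becomes impossible: the saved container never sees the sampler state.

```julia
# Hypothetical transition that bundles the sample with the sampler state.
struct TransitionWithState{T,S}
    sample::T        # the part that is stored in the chain
    samplerstate::S  # the part needed for the next step
end

# Hypothetical saving function: only the sample is kept.
save_without_state!(samples::Vector, t::TransitionWithState) = push!(samples, t.sample)

samples = Float64[]
t = TransitionWithState(1.5, (momentum = 0.3, stepsize = 0.1))
save_without_state!(samples, t)
# `samples` now contains 1.5, but `t.samplerstate` is stored nowhere else,
# so sampling cannot be resumed from `samples` alone.
```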

So, as mentioned in https://github.com/TuringLang/Turing.jl/pull/1186, one could switch to a setup with (more or less) stateless samplers and explicit states, something along the lines of:

```julia
# Proposed to live inside AbstractMCMC, which defines `AbstractModel`,
# `AbstractSampler` and the internal `@ifwithprogresslogger` macro.
import ProgressLogging
import Random

function mcmcsample(
    rng::Random.AbstractRNG,
    model::AbstractModel,
    sampler::AbstractSampler,
    N::Integer;
    progress = true,
    progressname = "Sampling",
    callback = (args...; kwargs...) -> nothing,
    chain_type::Type=Any,
    kwargs...
)
    # Check the number of requested samples.
    N > 0 || error("the number of samples must be ≥ 1")

    # Declare these here so that they remain accessible after the
    # progress-logging block, which may wrap its body in a closure.
    local samples, state

    @ifwithprogresslogger progress name=progressname begin
        # Obtain the initial state.
        state = initialstate(rng, model, sampler, N; kwargs...)

        # Run callback.
        callback(rng, model, sampler, N, 1, state; kwargs...)

        # Save the initial sample.
        samples = initsamples(state, model, sampler, N; kwargs...)
        savesample!(samples, 1, state, model, sampler, N; kwargs...)

        # Update the progress bar.
        progress && ProgressLogging.@logprogress 1/N

        # Step through the sampler.
        for i in 2:N
            # Obtain the updated state.
            state = step(rng, model, sampler, N, state; kwargs...)

            # Run callback.
            callback(rng, model, sampler, N, i, state; kwargs...)

            # Save the sample.
            savesample!(samples, i, state, model, sampler, N; kwargs...)

            # Update the progress bar.
            progress && ProgressLogging.@logprogress i/N
        end
    end

    # Compute the resulting MCMC chain.
    chain = samples2chain(rng, model, sampler, N, samples, state, chain_type; kwargs...)

    return chain, state
end

# Fallback: without a specific chain type, return the samples as they are.
function samples2chain(
    ::Random.AbstractRNG,
    ::AbstractModel,
    ::AbstractSampler,
    ::Integer,
    samples,
    state,
    ::Type{Any};
    kwargs...
)
    return samples
end
```

I renamed transitions to samples and used state instead of transition to make it clearer that each individual transition should be thought of as the current state of the sampler, which provides all the information needed to resume sampling.
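To check that the proposed interface hangs together, here is a self-contained toy version. Everything in it is hypothetical: the module `ToyMCMC` defines its own `AbstractModel` and `AbstractSampler` instead of using AbstractMCMC, implements `initialstate`, `step`, `initsamples` and `savesample!` for a one-dimensional random-walk Metropolis sampler, and uses a simplified driver without progress logging or callbacks. The extra `resume` function illustrates the point about resuming: since the returned state carries everything the next step needs, continuing from it with the same RNG should reproduce an uninterrupted run.

```julia
module ToyMCMC

using Random

# Hypothetical stand-ins for AbstractMCMC's abstract types.
abstract type AbstractModel end
abstract type AbstractSampler end

# A model given by an unnormalized log density.
struct DensityModel{F} <: AbstractModel
    logdensity::F
end

# Random-walk Metropolis: the sampler only stores its settings.
struct RWMH <: AbstractSampler
    stepsize::Float64
end

# The state holds everything needed to continue: current point and its log density.
struct RWMHState
    x::Float64
    logp::Float64
end

initialstate(rng, model::DensityModel, ::RWMH, N; kwargs...) =
    RWMHState(0.0, model.logdensity(0.0))

function step(rng, model::DensityModel, sampler::RWMH, N, state::RWMHState; kwargs...)
    xnew = state.x + sampler.stepsize * randn(rng)
    logpnew = model.logdensity(xnew)
    # Metropolis acceptance: accept with probability min(1, exp(logpnew - logp)).
    return log(rand(rng)) < logpnew - state.logp ? RWMHState(xnew, logpnew) : state
end

# Only the current point is saved; the container type follows from the state.
initsamples(::RWMHState, model, sampler, N; kwargs...) = Vector{Float64}(undef, N)
savesample!(samples, i, state::RWMHState, model, sampler, N; kwargs...) = (samples[i] = state.x)

# Simplified driver following the proposal (no progress bar or callback).
function mcmcsample(rng, model, sampler, N; kwargs...)
    N > 0 || error("the number of samples must be ≥ 1")
    state = initialstate(rng, model, sampler, N; kwargs...)
    samples = initsamples(state, model, sampler, N; kwargs...)
    savesample!(samples, 1, state, model, sampler, N; kwargs...)
    for i in 2:N
        state = step(rng, model, sampler, N, state; kwargs...)
        savesample!(samples, i, state, model, sampler, N; kwargs...)
    end
    return samples, state
end

# Continue sampling from a previously returned state instead of `initialstate`.
function resume(rng, model, sampler, N, state; kwargs...)
    samples = initsamples(state, model, sampler, N; kwargs...)
    for i in 1:N
        state = step(rng, model, sampler, N, state; kwargs...)
        savesample!(samples, i, state, model, sampler, N; kwargs...)
    end
    return samples, state
end

end # module

using Random
model = ToyMCMC.DensityModel(x -> -x^2 / 2)  # standard normal, up to a constant
sampler = ToyMCMC.RWMH(1.0)
samples, state = ToyMCMC.mcmcsample(MersenneTwister(42), model, sampler, 1_000)

# 5 samples, then 5 more resumed from the returned state with the same RNG,
# match a single run of 10 samples.
full, _ = ToyMCMC.mcmcsample(MersenneTwister(1), model, sampler, 10)
rng = MersenneTwister(1)
first_half, state5 = ToyMCMC.mcmcsample(rng, model, sampler, 5)
second_half, _ = ToyMCMC.resume(rng, model, sampler, 5, state5)
```

Note that `resume` skips `initialstate` entirely, which only works because the state is handed back to the caller instead of being hidden inside the sampler.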

Maybe one could even introduce a type SamplerWithState that would make it easier to bundle a state with the corresponding sampler and return this after sampling:

```julia
struct SamplerWithState{S<:AbstractSampler,T}
    sampler::S
    state::T
end
```

That could possibly also reduce the number of arguments in functions such as step or samples2chain.
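As a rough illustration of the argument reduction, here is a sketch in which a hypothetical `step` takes and returns a single `SamplerWithState` instead of separate sampler and state arguments. The module `BundledState` and the `RandomWalk` sampler are made up for this example.

```julia
module BundledState

using Random

abstract type AbstractSampler end

# Bundle of a sampler and its current state, as proposed above.
struct SamplerWithState{S<:AbstractSampler,T}
    sampler::S
    state::T
end

# Hypothetical random-walk sampler; its state is just the current point.
struct RandomWalk <: AbstractSampler
    stepsize::Float64
end

# One bundle in, one bundle out: the sampler and its state travel together.
function step(rng::AbstractRNG, sws::SamplerWithState{RandomWalk})
    x = sws.state + sws.sampler.stepsize * randn(rng)
    return SamplerWithState(sws.sampler, x)
end

end # module

using Random
sws = BundledState.SamplerWithState(BundledState.RandomWalk(0.5), 0.0)
rng = MersenneTwister(7)
for _ in 1:10
    global sws = BundledState.step(rng, sws)
end
```

The bundle's type parameters are fixed on construction, so the state stays concretely typed without the sampler needing a placeholder field.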

As a side note, IMO the callback signature could be simplified to callback(state, i) or even just callback(state), since all other arguments are known when the callback is defined and can therefore be used by the callback if it is defined as a closure over them.
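A small sketch of the simplified signature: the hypothetical driver `run_chain` calls `callback(state, i)`, and whatever else the callback needs (here a model name and a message log standing in for the model and sampler) is captured by the closure when it is defined.

```julia
# Hypothetical driver that calls the callback as `callback(state, i)`.
function run_chain(stepfn, state, N; callback = (state, i) -> nothing)
    callback(state, 1)
    for i in 2:N
        state = stepfn(state)
        callback(state, i)
    end
    return state
end

# Extra context is captured by the closure instead of being passed by the driver.
modelname = "toy"
messages = String[]
cb = (state, i) -> push!(messages, "$(modelname): iteration $i, state $state")

final = run_chain(x -> x + 1, 0, 3; callback = cb)
```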

cpfiffer commented 4 years ago

A welcome addition, I think. A lot of the crap that happens with sampler states in Turing is basically to strongly type the sampler before it makes it to the top-level mcmcsample call, but here the state information (the only untyped information) can be removed from Sampler and typed during initialstate. I think this is a much more natural and fluid approach to this, so it's a thumbs up from me.

That said, this is a fairly significant change, and we've had (and are planning to have) several breaking changes all at once. Do you have an idea of what the downstream costs of this would be to people like @mileslucas who have implemented some form of the interface?

devmotion commented 4 years ago

> That said, this is a fairly significant change, and we've had (and are planning to have) several breaking changes all at once. Do you have an idea of what the downstream costs of this would be to people like @mileslucas who have implemented some form of the interface?

You mean both the potential renaming plus new methods and the changed return types? I guess the main logic of any downstream package would not have to change; the most drastic change seems to be the different return type. Some changes, such as renames, should probably be handled smoothly by adding deprecations.
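For the renames, Julia's built-in `@deprecate` macro provides this kind of smooth transition. A sketch with hypothetical, simplified signatures (the real transitions_init takes different arguments): the old name keeps working by forwarding to the new one and, when deprecation warnings are enabled, tells users to switch.

```julia
module Renamed

# Hypothetical new name with a simplified signature, for illustration only.
initsamples(state, N) = Vector{typeof(state)}(undef, N)

# Deprecated alias under the old name that forwards to the new function.
@deprecate transitions_init(state, N) initsamples(state, N)

end # module

samples = Renamed.transitions_init(1.0, 3)  # still works, via `initsamples`
```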

cpfiffer commented 4 years ago

Works for me!