TuringLang / AbstractMCMC.jl

Abstract types and interfaces for Markov chain Monte Carlo methods
https://turinglang.org/AbstractMCMC.jl
MIT License

Add `sample` method to sample until convergence #16

Closed cpfiffer closed 4 years ago

cpfiffer commented 4 years ago

infer is intended to allow inference methods that do not have a preset N to be sampled through AbstractMCMC's interface. This may be spun off into a separate package or changed significantly, but for now it is an experimental feature that will hopefully be more inclusive of inference methods that don't fit the "sample 1,000 times" model.

The onus lies with the package developer to ensure that the sampling loop terminates, but I've added a try/catch block that catches only InterruptException, for the impatient who just want to stop sampling but keep the inference they've done so far.

@mileslucas take a look at this, particularly the testing suite. All you should need to define is a done_sampling function with the NestedSamplers types. Here's an example that terminates sampling when the mean of all the previous draws is close-ish to zero:

using Statistics: mean  # `mean` comes from the Statistics standard library

function AbstractMCMC.done_sampling(
    rng::AbstractRNG,
    model::MyModel,
    s::MySampler,
    transitions,
    iteration::Int;
    chain_type::Type=Any,
    kwargs...
)
    # Calculate the mean of the `b` field across all draws so far.
    bmean = mean(map(x -> x.b, transitions))

    # Stop when the mean is close to zero, or after 10,000 iterations.
    return isapprox(bmean, 0.0, atol=0.001) || iteration >= 10_000
end

After this, you should be able to call chain = infer(::NestedModel, ::Nested{E}) without too much trouble.

codecov[bot] commented 4 years ago

Codecov Report

Merging #16 into master will increase coverage by 2.48%. The diff coverage is 100.00%.


@@            Coverage Diff             @@
##           master      #16      +/-   ##
==========================================
+ Coverage   92.75%   95.23%   +2.48%     
==========================================
  Files           1        1              
  Lines          69       84      +15     
==========================================
+ Hits           64       80      +16     
+ Misses          5        4       -1     
Impacted Files Coverage Δ
src/AbstractMCMC.jl 95.23% <0.00%> (+2.48%) :arrow_up:


devmotion commented 4 years ago

IMO since this function is basically doing the same as the current implementation of sample just without specifying a fixed number of iterations, it should be called sample as well. The advantage of sample is also that it is not even owned by AbstractMCMC but just implements the method from StatsBase for specific models and samplers.

In general, I do not think AbstractMCMC should try to cover all possible use cases and implementations of sample. It is straightforward to just modify or adapt the default implementation of AbstractMCMC.sample in downstream packages, and IMO that is where modifications or implementations should live that are not of general interest and/or reused in multiple places (and even then, one could introduce packages like AbstractNestedSamplers that provide common functionality for nested samplers, e.g.). Moreover, AbstractMCMC already provides an implementation of an iterator for which the number of iterations is not fixed a priori?

More concretely, in this case I think something similar to Flux's default implementation of train! would maybe be preferable. Instead of users/developers having to implement a done_sampling function, one could just provide (basically copied from Flux)

struct StopException <: Exception end

function stop()
    throw(StopException())
end

and switch to a more general handling of callbacks (I have a draft lying around on my computer...). Then users can just call AbstractMCMC.stop() in their callback, and if AbstractMCMC catches a StopException, sampling would be terminated. This approach would be more flexible, since a user would be able to provide different callbacks for the same types of models and samplers (with the current approach there can be only one implementation of done_sampling for a fixed set of types). Additionally, we would no longer try to catch the very general InterruptException (e.g., if a user cancels execution of a cell in a notebook, we would currently catch this exception as well but not rethrow it, so some outer process might continue running).
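
The corresponding handling in the sampling loop might look like the following sketch. The loop itself and the step! call are illustrative only (step! is assumed from the AbstractMCMC interface of the time; this is not the actual implementation):

```julia
# Illustrative loop: catch only StopException and rethrow everything else,
# so e.g. InterruptException still propagates to the caller.
function sample_with_callback(rng, model, sampler, callback; kwargs...)
    transitions = []
    iteration = 0
    try
        while true
            iteration += 1
            t = step!(rng, model, sampler; kwargs...)  # assumed interface method
            push!(transitions, t)
            callback(rng, model, sampler, t, iteration)  # may call stop()
        end
    catch e
        e isa StopException || rethrow()
    end
    return transitions
end
```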

cpfiffer commented 4 years ago

IMO since this function is basically doing the same as the current implementation of sample just without specifying a fixed number of iterations, it should be called sample as well. The advantage of sample is also that it is not even owned by AbstractMCMC but just implements the method from StatsBase for specific models and samplers.

Yeah, I can change the name.

In general, I do not think AbstractMCMC should try to cover all possible use cases and implementations of sample. It is straightforward to just modify or adapt the default implementation of AbstractMCMC.sample in downstream packages, and IMO that is where modifications or implementations should live that are not of general interest and/or reused in multiple places (and even then, one could introduce packages like AbstractNestedSamplers that provide common functionality for nested samplers, e.g.). Moreover, AbstractMCMC already provides an implementation of an iterator for which the number of iterations is not fixed a priori?

Sure, I'd agree with this general point. But I also think that there has been friction over how sampling should work when N is not predefined, and I don't think people should have to hack that hard on AbstractMCMC to make that work. I don't think people should be overloading sample -- the whole point of sample is that you don't touch it, only the methods it calls. It's an interface; it's supposed to make things effortless, which I think this PR does pretty well.

I was notified of a pain point that comes from trying to shoehorn inference methods into an MCMC framework. Sure, nested sampling is not really MCMC, but it does implement all the methods in AbstractMCMC in a very idiomatic way. The methods here are actually pretty nice to work with even if you do not have an MCMC method. I don't think we should force these other inference methods to rewrite sample or define their own calls when we can do so in a very general way at the outset.

and switch to a more general handling of callbacks (I have a draft lying around on my computer...). Then users can just call AbstractMCMC.stop() in their callback, and if AbstractMCMC catches a StopException, sampling would be terminated. This approach would be more flexible, since a user would be able to provide different callbacks for the same types of models and samplers (with the current approach there can be only one implementation of done_sampling for a fixed set of types). Additionally, we would no longer try to catch the very general InterruptException (e.g., if a user cancels execution of a cell in a notebook, we would currently catch this exception as well but not rethrow it, so some outer process might continue running).

Sure, I'm fine with this, but I'd prefer to wait until your draft becomes a thing.

devmotion commented 4 years ago

I don't think we should force these other inference methods to rewrite sample or define their own calls when we can do so in a very general way at the outset

Yes, I agree. And I think specifying a sampling method without given number of iterations might be such a general and useful implementation. I only wanted to point out that in my opinion we shouldn't start thinking about every possible use case :slightly_smiling_face:

Sure, I'm fine with this, but I'd prefer to wait until your draft becomes a thing.

The draft is part of https://github.com/TuringLang/AbstractMCMC.jl/pull/17.

mileslucas commented 4 years ago

RE the namespacing and naming issues:

Could we set up a version of StatsBase.sample that doesn't take N as an argument and dispatches to whatever other method is appropriate, instead of splitting between train, infer, and sample?

Assuming this doesn't upset the natural dispatch flow, this would introduce the fewest new names into the API.
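
One way to read this suggestion is the following sketch (sample_until_converged is a hypothetical internal function, not part of any package):

```julia
import StatsBase

# Hypothetical: a `sample` method without `N` that dispatches to a
# convergence-based implementation instead of introducing a new name.
function StatsBase.sample(model::AbstractModel, sampler::AbstractSampler; kwargs...)
    return sample_until_converged(model, sampler; kwargs...)  # hypothetical
end
```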

mileslucas commented 4 years ago

PS @devmotion you can see at my repo how I've fit nested sampling into the AbstractMCMC framework. As Cameron pointed out, it fits the AbstractMCMC idioms nearly perfectly except for this N business.

cpfiffer commented 4 years ago

Sounds good to me. When I get home this evening I'll change it over to use the callback functionality.

cpfiffer commented 4 years ago

I made a couple changes here, sorry @yebai, I think I've invalidated your suggestions above.

cpfiffer commented 4 years ago

I also decided to leave done_sampling as a separate function rather than assume that it should go through the callbacks, particularly because I want something that is very well defined and distinct from the callbacks, which I think of as "extra stuff" that someone might want to occur during sampling. done_sampling is a requirement for using this method; callbacks are not.

devmotion commented 4 years ago

Hmm I'm still not convinced that one should have to define a done_sampling function, mainly due to two reasons: it restricts one to only one stopping condition for a specific type of models and samplers (if one does not redefine done_sampling) and there is not a useful default implementation (one always has to implement a stopping condition anyways).

However, I can see the point that one maybe wants to separate the stopping condition from a regular callback more clearly, so an alternative, maybe superior approach would be to provide an arbitrary function that implements the stopping condition as an additional argument (basically instead of the number of iterations N). That would remove the need for StopException, enable the use of arbitrary stopping conditions for a specific type of models and samplers, and also separate the stopping condition from regular callbacks.

cpfiffer commented 4 years ago

Hmm I'm still not convinced that one should have to define a done_sampling function, mainly due to two reasons: it restricts one to only one stopping condition for a specific type of models and samplers (if one does not redefine done_sampling) and there is not a useful default implementation (one always has to implement a stopping condition anyways).

I think the reason I want this to be distinct is that you should be forced to think quite hard about when/whether you want your inference algorithm to support this functionality. If you are working with a traditional MCMC algorithm, this doesn't make much sense unless you have some well defined convergence logic built in.

Plus, you can define "families" of inference algorithms that have their own defaults, so you don't have to do this for every possible algorithm. Let's say you want to disable this functionality for MCMC:

# Defaults for MCMC algorithms.
abstract type MCMCSampler <: AbstractSampler end
struct MetropolisHastings <: MCMCSampler end

function done_sampling(
    rng::AbstractRNG,
    model::AbstractModel,
    s::MetropolisHastings,
    transitions,
    iteration::Int;
    chain_type::Type=Any,
    kwargs...
)
    # MCMC algorithms have no built-in convergence check, so bail out immediately.
    @warn("Convergence checks not supported for MCMC algorithms.")
    throw(StopException())
end

Then, if you've got a family of algorithms like nested sampling, you can build this categorical check:

# Defaults for sampler families that have some convergence aspect.
abstract type OtherSampler <: AbstractSampler end
struct NestedSampler <: OtherSampler end

function done_sampling(
    rng::AbstractRNG,
    model::AbstractModel,
    s::OtherSampler,
    transitions,
    iteration::Int;
    chain_type::Type=Any,
    kwargs...
)
    if converged(rng, model, s, transitions)
        throw(StopException())
    end
end

However, I can see the point that one maybe wants to separate the stopping condition from a regular callback more clearly, so an alternative, maybe superior approach would be to provide an arbitrary function that implements the stopping condition as an additional argument (basically instead of the number of iterations N). That would remove the need for StopException, enable the use of arbitrary stopping conditions for a specific type of models and samplers, and also separate the stopping condition from regular callbacks.

I'm not really sure what you mean -- do you have an example of what this could be?

mileslucas commented 4 years ago

Right now I am dispatching convergence criteria through the keyword args and I think that’s okay.

cpfiffer commented 4 years ago

Do you have an example of what that looks like?

devmotion commented 4 years ago

I'm not really sure what you mean -- do you have an example of what this could be?

If you want the regular callbacks being separated but still allow for arbitrary convergence criteria without overloading some done_sampling method (since, as mentioned above, this restricts you to one convergence criterion for every combination of model type and sampler type), you could use

sample(rng, model, sampler, isdone; kwargs...)

where isdone returns true if the sampling should be stopped and false otherwise (basically it is done_sampling without StopException). If you want to define a default convergence criterion for a certain type of samplers you can just define

sample(rng, model, sampler::MySpecialSamplers; kwargs...) = sample(rng, model, sampler, defaultisdone; kwargs...)
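
To make the proposal concrete, an isdone matching this signature could mirror the earlier done_sampling example (a sketch; Statistics' mean and the x.b field are assumptions carried over from the example above):

```julia
using Statistics: mean

# Sketch: an isdone function matching the proposed positional-argument form.
function isdone(rng, model, sampler, transitions, iteration; kwargs...)
    bmean = mean(map(x -> x.b, transitions))  # mean of the `b` field so far
    return isapprox(bmean, 0.0, atol=0.001) || iteration >= 10_000
end

# chain = sample(rng, model, sampler, isdone)
```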

cpfiffer commented 4 years ago

Updated to use a passed function.

mileslucas commented 4 years ago

Also, I would still love to see a progress bar of some form.

Right now I’ve added a ProgressThresh from ProgressMeter.jl and it works really well but requires manually passing to my is_done function to get the updated value.

Seeing the threshold and criterion converge is quite satisfying, and seeing the total iterations and time after it finishes is quite nice too.
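
A sketch of what that wiring might look like (ProgressThresh is from ProgressMeter.jl; logz_remaining and the 0.5 threshold are made up for illustration, and the is_done signature is assumed from the discussion above):

```julia
using ProgressMeter

const prog = ProgressThresh(0.5, "Convergence (dlogz):")

function is_done(rng, model, sampler, transitions, iteration; kwargs...)
    dlogz = logz_remaining(transitions)  # hypothetical convergence measure
    ProgressMeter.update!(prog, dlogz)   # display threshold vs. current value
    return dlogz < 0.5
end
```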

cpfiffer commented 4 years ago

Right now I’ve added a ProgressThresh from ProgressMeter.jl and it works really well but requires manually passing to my is_done function to get the updated value.

This is probably what it would be anyway, since there's not really a one-size-fits-all approach for this particular method. For the regular sample, it's really easy to define a progress meter, but when you're doing this convergence sampling thing I'm not sure what exactly would be the most general case that serves all purposes. I suspect custom threshold reporting is probably better handled by a callback or inside your is_done function.

cpfiffer commented 4 years ago

I did wrap the internal stuff in logging code so you should be able to generate log messages with progress of some kind if you want to use ProgressLogging for convergence.

devmotion commented 4 years ago

In general, I'm wondering if we should completely remove the argument N from sample_init!, transitions_save!, step!, sample_end!, and bundle_samples!. Intuitively, it should not be needed there - indicated also by the fact that we just pass N = 1 in the iterator and the new sample! function. That would simplify the dispatches and make the implementation a bit more intuitive, IMO.

I'm not sure if it is needed for the callback function - in principle, it could always be created as a closure over N or isdone in case it is needed (actually the same with rng, model, sampler, and kwargs...). However, if we pass N in the regular implementation of sample, it would feel consistent to pass isdone in the new implementation instead of N = 1. Similarly, I think one might want to pass isdone to transitions_init.
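For instance, a callback could be created as a closure over N, so N need not be threaded through the sampling internals (a sketch; the callback signature here is assumed from the discussion above):

```julia
# Sketch: build a callback that captures N via closure.
function make_callback(N)
    return (rng, model, sampler, transition, iteration) -> begin
        iteration % 100 == 0 && @info "iteration $iteration of $N"
    end
end
```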

mileslucas commented 4 years ago

Following that logic, “normal” sampling is just using a convergence criterion based on a maximum iteration count, which would be an interesting paradigm shift.

cpfiffer commented 4 years ago

I think that's wise, but I also think we've had a fair amount of breaking changes recently -- let's revisit removing N and generally rethinking what gets dispatched where in another PR.