Closed cpfiffer closed 4 years ago
Merging #16 into master will increase coverage by 2.48%. The diff coverage is 100.00%.
```
@@            Coverage Diff             @@
##           master      #16      +/-  ##
==========================================
+ Coverage   92.75%   95.23%   +2.48%
==========================================
  Files           1        1
  Lines          69       84      +15
==========================================
+ Hits           64       80      +16
+ Misses          5        4       -1
```

| Impacted Files | Coverage Δ | |
|---|---|---|
| src/AbstractMCMC.jl | 95.23% <0.00%> (+2.48%) | :arrow_up: |
Powered by Codecov. Last update 4bfe617...7ec741f.
IMO since this function is basically doing the same as the current implementation of `sample`, just without specifying a fixed number of iterations, it should be called `sample` as well. The advantage of `sample` is also that it is not even owned by `AbstractMCMC` but just implements the method from `StatsBase` for specific models and samplers.
In general, I do not think AbstractMCMC should try to cover all possible use cases and implementations of `sample`. It is straightforward to just modify or adapt the default implementation of `AbstractMCMC.sample` in downstream packages, and IMO that is where modifications or implementations should live that are not of general interest and/or reused in multiple places (and even then, one could introduce packages like `AbstractNestedSamplers` that provide common functionality for nested samplers, e.g.). Moreover, `AbstractMCMC` already provides an implementation of an iterator for which the number of iterations is not fixed a priori?
More concretely, in this case I think something similar to Flux's default implementation of `train!` would maybe be preferable. Instead of users/developers having to implement a `done_sampling` function, one could just provide (basically copied from Flux)

```julia
struct StopException <: Exception end

function stop()
    throw(StopException())
end
```
and switch to a more general handling of callbacks (I have a draft lying around on my computer...). Then users can just call `AbstractMCMC.stop()` in their callback, and if AbstractMCMC catches a `StopException`, sampling would be terminated. This approach would be more flexible, since a user is able to provide different callbacks for the same types of models and samplers (with the current approach there can be only one implementation of `done_sampling` for a fixed set of types), and additionally we would not try to catch the very general `InterruptException` anymore (e.g., if a user cancels execution of a cell in a notebook, we would currently catch this exception as well but not rethrow it, so some outer process might continue running).
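To make the idea concrete, here is a minimal, self-contained sketch of the proposed pattern: a user callback signals termination by throwing `StopException`, and the loop catches only that exception type, rethrowing everything else (including `InterruptException`). The loop and function names are hypothetical, not part of AbstractMCMC's API.

```julia
using Statistics

# Hypothetical sketch of the StopException-based callback pattern.
struct StopException <: Exception end
stop() = throw(StopException())

function sample_until_stopped(draw, callback; maxiters = 10_000)
    samples = Float64[]
    try
        for i in 1:maxiters
            push!(samples, draw())
            callback(samples)  # the callback may call stop() to terminate sampling
        end
    catch e
        # Only swallow our own exception; InterruptException etc. are rethrown.
        e isa StopException || rethrow()
    end
    return samples
end

# Example: stop once 100 draws are collected and their mean settles near zero.
chain = sample_until_stopped(randn, s -> length(s) >= 100 && abs(mean(s)) < 0.1 && stop())
```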
> IMO since this function is basically doing the same as the current implementation of `sample`, just without specifying a fixed number of iterations, it should be called `sample` as well. The advantage of `sample` is also that it is not even owned by `AbstractMCMC` but just implements the method from `StatsBase` for specific models and samplers.
Yeah, I can change the name.
> In general, I do not think AbstractMCMC should try to cover all possible use cases and implementations of `sample`. It is straightforward to just modify or adapt the default implementation of `AbstractMCMC.sample` in downstream packages, and IMO that is where modifications or implementations should live that are not of general interest and/or reused in multiple places (and even then, one could introduce packages like `AbstractNestedSamplers` that provide common functionality for nested samplers, e.g.). Moreover, `AbstractMCMC` already provides an implementation of an iterator for which the number of iterations is not fixed a priori?
Sure, I'd agree with this general point. But I also think that there has been friction with how sampling should work when `N` is not predefined, and I don't think that people should have to hack that hard on AbstractMCMC to make that work. I don't think people should be overloading `sample` -- the whole point of `sample` is that you don't touch it, only the methods it calls. It's an interface; it's supposed to make things effortless, which I think this PR does pretty well.

I was notified of a pain point that comes from trying to shoehorn inference methods into an MCMC framework. Sure, nested sampling is not really MCMC, but it does implement all the methods in AbstractMCMC in a very idiomatic way. The methods here are actually pretty nice to work with even if you do not have an MCMC method. I don't think we should force these other inference methods to rewrite `sample` or define their own calls when we can do so in a very general way at the outset.
> and switch to a more general handling of callbacks (I have a draft lying around on my computer...). Then users can just call `AbstractMCMC.stop()` in their callback, and if AbstractMCMC catches a `StopException`, sampling would be terminated. This approach would be more flexible since a user is able to provide different callbacks for the same types of models and samplers (with the current approach there can be only one implementation for `done_sampling` for a fixed set of types) and additionally we do not try to catch the very general `InterruptException` anymore (e.g., if a user cancels execution of a cell in a notebook, we would catch this exception as well but not rethrow it, so some outer process might continue running).
Sure, I'm fine with this, but I'd prefer to wait until your draft becomes a thing.
> I don't think we should force these other inference methods to rewrite `sample` or define their own calls when we can do so in a very general way at the outset.
Yes, I agree. And I think specifying a sampling method without given number of iterations might be such a general and useful implementation. I only wanted to point out that in my opinion we shouldn't start thinking about every possible use case :slightly_smiling_face:
> Sure, I'm fine with this, but I'd prefer to wait until your draft becomes a thing.
The draft is part of https://github.com/TuringLang/AbstractMCMC.jl/pull/17.
RE the namespacing and naming issues:

Could we set up a version of `StatsBase.sample` that doesn't have `N` as its argument and then dispatches to whatever other method, instead of splitting between `train`, `infer`, and `sample`? Assuming this doesn't upset the natural dispatch flow, this would introduce the fewest new names into the API.
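A hypothetical sketch of what that layering could look like; all names other than `sample` (e.g. `mcmcsample`, `sample_until_done`) are illustrative and not part of the actual API, and the abstract types are assumed to come from AbstractMCMC.

```julia
import StatsBase: sample

# Hypothetical sketch: dispatch on the presence/absence of `N` instead of
# introducing new verbs like `train` or `infer`. `AbstractModel` and
# `AbstractSampler` are assumed to be AbstractMCMC's types.
sample(model::AbstractModel, sampler::AbstractSampler, N::Integer; kwargs...) =
    mcmcsample(model, sampler, N; kwargs...)          # fixed-length sampling

sample(model::AbstractModel, sampler::AbstractSampler; kwargs...) =
    sample_until_done(model, sampler; kwargs...)      # convergence-based sampling
```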
PS @devmotion you can see at my repo how I've fit nested sampling into the AbstractMCMC framework. As Cameron pointed out, it fits the AbstractMCMC idioms nearly perfectly except for this `N` business.
Sounds good to me. When I get home this evening I'll change it over to use the callback functionality.
I made a couple changes here, sorry @yebai, I think I've invalidated your suggestions above.
`infer` is now just `sample` without the `N` argument. `done_sampling` throws a `StopException` when convergence is met, and it will terminate sampling early by default if the user has not overloaded `done_sampling`. `infer`/`sample` supports callbacks now.

I also decided to leave `done_sampling` as a separate function rather than assume that it should go through the callbacks, particularly because I want something that is very well defined and distinct from the callbacks, which I think of as "extra stuff" that someone might want to occur during sampling. `done_sampling` is a requirement for using this method; callbacks are not.
Hmm, I'm still not convinced that one should have to define a `done_sampling` function, mainly for two reasons: it restricts one to only one stopping condition for a specific type of models and samplers (if one does not redefine `done_sampling`), and there is no useful default implementation (one always has to implement a stopping condition anyway).

However, I can see the point that one maybe wants to separate the stopping condition from a regular callback more clearly, so an alternative, maybe superior, approach would be to provide an arbitrary function that implements the stopping condition as an additional argument (basically instead of the number of iterations `N`). That would remove the need for `StopException`, enable the use of arbitrary stopping conditions for a specific type of models and samplers, and also separate the stopping condition from regular callbacks.
> Hmm I'm still not convinced that one should have to define a `done_sampling` function, mainly due to two reasons: it restricts one to only one stopping condition for a specific type of models and samplers (if one does not redefine `done_sampling`) and there is not a useful default implementation (one always has to implement a stopping condition anyway).
I think the reason I want this to be distinct is that you should be forced to think quite hard about when/whether you want your inference algorithm to support this functionality. If you are working with a traditional MCMC algorithm, this doesn't make much sense unless you have some well defined convergence logic built in.
Plus, you can define "families" of inference algorithms that have their own defaults, so you don't have to do this for every possible algorithm. Let's say you want to disable this functionality for MCMC:
```julia
# Defaults for MCMC algorithms.
abstract type MCMCSampler <: AbstractSampler end
struct MetropolisHastings <: MCMCSampler end

function done_sampling(
    rng::AbstractRNG,
    model::AbstractModel,
    s::MetropolisHastings,
    transitions,
    iteration::Int;
    chain_type::Type=Any,
    kwargs...
)
    @warn("Convergence checks not supported for MCMC algorithms.")
    throw(StopException())
end
```
Then, if you've got a family of algorithms like nested sampling, you can build this categorical check:
```julia
# Defaults for sampler families that have some convergence aspect.
abstract type OtherSampler <: AbstractSampler end
struct NestedSampler <: OtherSampler end

function done_sampling(
    rng::AbstractRNG,
    model::AbstractModel,
    s::OtherSampler,
    transitions,
    iteration::Int;
    chain_type::Type=Any,
    kwargs...
)
    if converged(rng, model, s, transitions)
        throw(StopException())
    end
end
```
> However, I can see the point that one maybe wants to separate the stopping condition from a regular callback more clearly, so an alternative, maybe superior approach would be to provide an arbitrary function that implements the stopping condition as an additional argument (basically instead of the number of iterations `N`). That would remove the need for `StopException`, enable the use of arbitrary stopping conditions for a specific type of models and samplers, and also separate the stopping condition from regular callbacks.
I'm not really sure what you mean -- do you have an example of what this could be?
Right now I am dispatching convergence criteria through the keyword args and I think that’s okay.
Do you have an example of what that looks like?
> I'm not really sure what you mean -- do you have an example of what this could be?
If you want the regular callbacks separated but still want to allow for arbitrary convergence criteria without overloading some `done_sampling` method (since, as mentioned above, this restricts you to one convergence criterion for every combination of model type and sampler type), you could use

```julia
sample(rng, model, sampler, isdone; kwargs...)
```

where `isdone` returns `true` if the sampling should be stopped and `false` otherwise (basically it is `done_sampling` without `StopException`). If you want to define a default convergence criterion for a certain type of samplers, you can just define

```julia
sample(rng, model, sampler::MySpecialSamplers; kwargs...) = sample(rng, model, sampler, defaultisdone; kwargs...)
```
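A concrete `isdone` in this style might look like the following sketch. The convergence measure (running mean of the draws near zero) and the assumption that `transitions` collects to scalar draws are mine, purely for illustration.

```julia
using Statistics

# Hypothetical `isdone` matching the signature sketched above: stop once
# the running mean of the draws settles near zero (assumes scalar transitions).
function isdone(rng, model, sampler, transitions, iteration; kwargs...)
    return iteration >= 10 && abs(mean(transitions)) < 0.01
end
```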
Updated to use a passed function.
Also, I would still love to see a progress bar of some form.

Right now I've added a `ProgressThresh` from ProgressMeter.jl and it works really well, but it requires manually passing it to my `is_done` function to get the updated value. Seeing the threshold and criterion converge is quite satisfying, and seeing the total iterations and time after it finishes is quite nice too.
> Right now I've added a `ProgressThresh` from ProgressMeter.jl and it works really well but requires manually passing it to my `is_done` function to get the updated value.
This is probably what it would be anyway, since there's not really a one-size-fits-all approach for this particular method. For the regular `sample`, it's really easy to define a progress meter, but when you're doing this convergence sampling thing I'm not sure what exactly would be the most general case that serves all purposes. I suspect custom threshold reporting is probably better handled by `callback` or inside your `is_done` function.

I did wrap the internal stuff in logging code, so you should be able to generate log messages with progress of some kind if you want to use `ProgressLogging` for convergence.
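For reference, wiring a `ProgressThresh` into a stopping check might look roughly like the sketch below. The convergence measure and all function names besides ProgressMeter's `ProgressThresh` and `update!` are illustrative assumptions.

```julia
using ProgressMeter, Statistics

# Sketch: drive a ProgressThresh from inside the stopping check, so the
# meter finishes exactly when the convergence criterion is met.
const prog = ProgressThresh(1e-3; desc = "Converging: ")

function is_done(rng, model, sampler, transitions, iteration; kwargs...)
    criterion = abs(mean(transitions))   # illustrative convergence measure
    ProgressMeter.update!(prog, criterion)
    return criterion < 1e-3
end
```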
In general, I'm wondering if we should completely remove the argument `N` from `sample_init!`, `transitions_save!`, `step!`, `sample_end!`, and `bundle_samples!`. Intuitively, it should not be needed there - indicated also by the fact that we just pass `N = 1` in the iterator and the new `sample!` function. That would simplify the dispatches and make the implementation a bit more intuitive, IMO.

I'm not sure if it is needed for the `callback` function - in principle, it could always be created as a closure over `N` or `isdone` in case it is needed (actually the same with `rng`, `model`, `sampler`, and `kwargs...`). However, if we pass `N` in the regular implementation of `sample`, it would feel consistent to pass `isdone` in the new implementation instead of `N = 1`. Similarly, I think one might want to pass `isdone` to `transitions_init`.
Following that logic, "normal" sampling is just using a convergence criterion based on a maximum number of iterations, which would be an interesting paradigm shift.
I think that's wise, but I also think we've had a fair amount of breaking changes recently -- let's revisit removing `N` and generally rethinking what gets dispatched where in another PR.
`infer` is intended to allow inference methods that do not have a preset `N` to be sampled using AbstractMCMC's interface. This may be spun off into a different package or changed significantly, but for right now it is an experimental feature that will hopefully be more inclusive of inference methods that don't fit the "sample 1,000 times" model.

The onus lies with the package developer to ensure that the sampling loop terminates, but I've added a try/catch loop that only catches `InterruptException` for the impatient who just want to stop sampling but keep the inference they've done already.

@mileslucas take a look at this, particularly the testing suite. All you should need to define is a `done_sampling` function with the `NestedSamplers` types. Here's an example that terminates sampling when the mean of all the previous draws is close-ish to zero. After this, you should be able to call `chain = infer(::NestedModel, ::Nested{E})` without too much trouble.
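The example mentioned above doesn't appear in the thread; a minimal sketch of what such a `done_sampling` might look like follows. The threshold, warm-up count, and the assumption that the transitions collect to scalar draws are all mine.

```julia
using Statistics

# Hypothetical reconstruction of the omitted example: terminate sampling
# once the mean of all previous draws is close-ish to zero
# (assumes `transitions` collects to scalar draws).
function done_sampling(rng, model, s, transitions, iteration; kwargs...)
    if iteration >= 10 && abs(mean(transitions)) < 0.05
        throw(StopException())
    end
end
```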