Closed torfjelde closed 2 months ago
It looks like a good one for @mhauru to review.
I was just today reading through DynamicPPL and noticed that `model.context` is a thing, and had to take a while to figure out why, given the explicit passing around of a separate context. Do I understand correctly that functionally the changes here are equivalent to changing

```julia
context::AbstractContext=model.context
```

to

```julia
context::AbstractContext=leafcontext(model.context)
```

and the rest, introducing the `nothing` and the `getcontext` function, are an aesthetic preference for not having the `LogDensityFunction` struct store redundant data?
I think we should unify these contexts eventually, although not necessarily in this PR.
I lean towards contextualising a model before passing it to an `evaluate!!` function:

```julia
# check for invalid context composition; note that `contextualising!!` could be called more than once
model_with_context = contextualising!!(model, context)
res = evaluate!!(rng, model_with_context, ...) # remove context argument here
```
If a model is conditioned, when we contextualise it again, it can throw an error in cases where context composition is invalid. This is probably the same as @torfjelde's idea above: removing the `context` argument from `evaluate!!` completely but introducing an explicit `contextualising!!` function.
> and the rest, introducing the `nothing` and the `getcontext` function, are an aesthetic preference for not having the `LogDensityFunction` struct store redundant data?
It's not so much about "not storing redundant data", but rather about "lazily" resolving the context in the case of `f.context === nothing`; passing `leafcontext(model.context)` is effectively a no-op.
> This is probably the same as @torfjelde's idea above, removing the `context` argument from `evaluate!!` completely but introducing an explicit `contextualising!!` function.
We already have this: `contextualize(model, context)`. But otherwise, yes, we're on the same page: replace all calls to

```julia
evaluate!!(model, varinfo, context)
```

with

```julia
evaluate!!(contextualize(model, context), varinfo)
```
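To make the proposed rewrite concrete, here is a hedged sketch of the pattern (assuming DynamicPPL and Distributions are loaded; `demo` is a made-up model, not one from this PR):

```julia
using DynamicPPL, Distributions

@model demo() = x ~ Normal()

model = demo()
varinfo = VarInfo(model)

# Before: context passed explicitly as a third argument.
# DynamicPPL.evaluate!!(model, varinfo, PriorContext())

# After: attach the context to the model first, then evaluate.
retval, varinfo_new = DynamicPPL.evaluate!!(
    contextualize(model, PriorContext()), varinfo
)
```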
But, as I said above, this isn't so easy because we use explicit context-passing in quite a few places beyond `evaluate!!`.
> But, as I said above, this isn't so easy because we use explicit context-passing in quite a few places beyond `evaluate!!`.
Is there any other difficulty other than finding all the places and then updating them?
> Is there any other difficulty other than finding all the places and then updating them?
It's a question of what you do with methods such as `getindex(varinfo, context)` (which are there because we need to let samplers pick out slices of `varinfo`, as we haven't yet replaced the Gibbs sampler). Once we have replaced the Gibbs sampler fully, we'll be able to drop a lot of these methods where we use `context` / `sampler`, and then it'll indeed be "just finding all the places and replacing them".
Something weird is happening with Documenter.jl here? Seems like everything is missing some whitespace o.O
| Changes Missing Coverage | Covered Lines | Changed/Added Lines | % |
|---|---|---|---|
| src/logdensityfunction.jl | 7 | 8 | 87.5% |

| Totals | |
|---|---|
| Change from base Build 9495604846: | -0.01% |
| Covered Lines: | 2759 |
| Relevant Lines: | 3434 |
This should be ready
The integration test fails because Turing.jl has code that tries to access the `LogDensityFunction.context` field directly and is not prepared to handle it being `nothing`. This may be a more prevalent issue, i.e. this PR may be breaking for many dependants.

One solution to this would be to go and fix all the dependants to use `getcontext(x)` rather than `x.context`. This would be nicer long-term IMO, but would require making edits to all those packages, and might cause breakage in the meanwhile if dependants don't specify version upper bounds for DynamicPPL. It would also mean that those packages then don't work with older versions of DynamicPPL anymore.
Another option would be to override `getproperty` for `LogDensityFunction`, so that `x.context` would actually return `getcontext(x)`. This may be a bit overkill in terms of using a somewhat "deep" Julia construct (`getproperty`) to solve quite a simple problem, but it would be minimally disruptive.
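For reference, the `getproperty` route could look something like this self-contained sketch. The types here (`DemoModel`, `DemoLogDensityFunction`) are stand-ins for illustration, not DynamicPPL's actual structs:

```julia
# Stand-in types to illustrate the idea; not DynamicPPL's real structs.
struct DemoModel
    context::Symbol  # stand-in for an AbstractContext
end

struct DemoLogDensityFunction
    model::DemoModel
    context::Union{Nothing,Symbol}  # `nothing` means "fall back to the model's context"
end

# Resolve the context lazily; use `getfield` here to avoid recursing
# back into the overridden `getproperty`.
function getcontext(f::DemoLogDensityFunction)
    ctx = getfield(f, :context)
    return ctx === nothing ? getfield(f, :model).context : ctx
end

# Keep `f.context` working for downstream code by forwarding to `getcontext`.
function Base.getproperty(f::DemoLogDensityFunction, name::Symbol)
    name === :context && return getcontext(f)
    return getfield(f, name)
end

f = DemoLogDensityFunction(DemoModel(:default_context), nothing)
f.context  # resolves to the model's context instead of `nothing`
```

The key subtlety is that any internal accessor must use `getfield` directly, since plain `f.context` would recurse into the overridden `getproperty`.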
Relevant Julia style guide page: https://docs.julialang.org/en/v1/manual/style-guide/#Prefer-exported-methods-over-direct-field-access
I lean towards the former solution. Other thoughts?
> One solution to this would be to go and fix all the dependants to use `getcontext(x)` rather than `x.context`. This would be nicer long-term IMO, but would require making edits to all those packages, and might cause breakage in the meanwhile if dependants don't specify version upper bounds for DynamicPPL.
@mhauru, can you create a PR for Turing that adopts the suggestion you propose above? For packages without DynamicPPL bounds, that's unfortunate; maybe this is the opportunity for such bounds to be added. However, I am not aware of any package depending on DynamicPPL without an explicit version bound.
Also, does that mean this PR can be merged as a breaking release?
Yep, we can make this a breaking release and be fine. I'll make the Turing.jl PR tomorrow.
| Changes Missing Coverage | Covered Lines | Changed/Added Lines | % |
|---|---|---|---|
| src/logdensityfunction.jl | 6 | 8 | 75.0% |

| Files with Coverage Reduction | New Missed Lines | % |
|---|---|---|
| src/logdensityfunction.jl | 1 | 61.9% |
| src/sampler.jl | 1 | 94.12% |
| ext/DynamicPPLForwardDiffExt.jl | 1 | 77.78% |
| src/contexts.jl | 2 | 77.27% |
| src/threadsafe.jl | 4 | 50.0% |
| src/abstract_varinfo.jl | 5 | 82.68% |
| src/context_implementations.jl | 8 | 58.63% |
| src/model_utils.jl | 10 | 19.64% |
| src/loglikelihoods.jl | 16 | 54.84% |
| src/varinfo.jl | 55 | 85.36% |

| Totals | |
|---|---|
| Change from base Build 9495604846: | -3.7% |
| Covered Lines: | 2642 |
| Relevant Lines: | 3448 |
| Changes Missing Coverage | Covered Lines | Changed/Added Lines | % |
|---|---|---|---|
| src/logdensityfunction.jl | 6 | 8 | 75.0% |

| Files with Coverage Reduction | New Missed Lines | % |
|---|---|---|
| src/logdensityfunction.jl | 1 | 61.9% |
| src/sampler.jl | 1 | 94.12% |
| ext/DynamicPPLForwardDiffExt.jl | 1 | 77.78% |
| src/contexts.jl | 2 | 77.27% |
| src/threadsafe.jl | 4 | 50.0% |
| src/abstract_varinfo.jl | 5 | 82.68% |
| src/context_implementations.jl | 8 | 58.63% |
| src/model_utils.jl | 10 | 19.64% |
| src/loglikelihoods.jl | 16 | 54.84% |
| src/varinfo.jl | 55 | 85.36% |

| Totals | |
|---|---|
| Change from base Build 9495604846: | -2.8% |
| Covered Lines: | 2658 |
| Relevant Lines: | 3427 |
Thanks for getting this through!
> Another option would be to override `getproperty` for `LogDensityFunction`, so that `x.context` would actually return `getcontext(x)`. This may be a bit overkill in terms of using a somewhat "deep" Julia construct (`getproperty`) to solve quite a simple problem, but it would be minimally disruptive.
Agree that this would have been overkill :)
## Issue

When evaluating a `Model`, there are two "sources" of contexts provided: 1) explicitly passed to `evaluate!!` as an argument, and 2) through the context attached to the model itself in `model.context`.

The latter was introduced because in many scenarios it makes sense to "contextualize" a `Model`, e.g. attach a `ConditionContext` to a `Model` to specify which parameters are considered conditioned. The former is present because back in the day the samplers were heavily tied to DynamicPPL and we passed a `sampler` argument in almost every place where we now pass `context`.

To "bridge" the two approaches, when we call `evaluate!!`, the process of "resolving" the context that eventually ends up as `__context__` in the model itself occurs here:

https://github.com/TuringLang/DynamicPPL.jl/blob/d384da21e168ca089f16551231913700a4454efc/src/model.jl#L995-L997
In short, we do:

1. Replace the leaf context of `model.context` with the leaf context provided by the `context` argument.
2. Set the leaf context of the `context` argument to the context resulting from (1).

This might be a bit strange, but the result is that `context` takes precedence over `model.context`, as it's considered to be "more important" due to the "user" explicitly passing it to `evaluate!!`.
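In code, the resolution described above amounts to something like the following hedged sketch. `resolve_context` is a made-up helper name, not the actual function in `model.jl`; `setleafcontext` and `leafcontext` are DynamicPPL's context utilities:

```julia
using DynamicPPL: setleafcontext, leafcontext

# Made-up helper; mirrors the two-step resolution described above.
function resolve_context(model_context, context)
    # (1) Replace the leaf of `model.context` with the leaf provided by `context`.
    inner = setleafcontext(model_context, leafcontext(context))
    # (2) Set the leaf of the `context` argument to the result of (1),
    #     so the non-leaf parts of `context` end up outermost.
    return setleafcontext(context, inner)
end
```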
We did this because some samplers were using contexts to specify certain behaviors that had to be respected, e.g. `context` could be a `PriorContext` to indicate that the prior should be evaluated while `model.context` could be a `DefaultContext`, in which case we wanted the result to be `PriorContext`.

This also means that `LogDensityFunction`, effectively a convenient wrapper around `evaluate!!`, also has two sources of contexts: `f.model.context` and `f.context`. By default, i.e. if we call `LogDensityFunction(model)`, we specify these to be the same, i.e. `f.context === f.model.context`. This is clearly very redundant, since we're just specifying the same context twice. Moreover, since, as seen above, we effectively concatenate `context` and `model.context`, this results in `LogDensityFunction` evaluating the model with the context "doubled". In most cases this still results in the intended behavior, but once you start changing certain fields of the `LogDensityFunction`, e.g. `LogDensityFunction(model_new, f.varinfo, f.context)`, interesting things can happen. For example, in https://github.com/TuringLang/Turing.jl/pull/2231, I ran into an issue where I'd get two `ConditionContext`s conditioning the same variable: one from `f.model.context` (what I intended) and one from `f.context` (what I did not intend).

## Solution
This PR addresses this issue for `LogDensityFunction` by simply allowing `nothing` in `f.context`, which is resolved to `leafcontext(model.context)` if not specified. This addresses the issues I've encountered above, but the proper way of fixing this would, IMO, be to either:

1. Allow `nothing` to be passed in place of `context` "everywhere", i.e. we make them all `Optional{AbstractContext} = Union{Nothing,AbstractContext}` types, and resolve to `model.context` whenever it's `nothing`.
2. Remove the `context` argument from `evaluate!!` completely and instead always just "attach" the `context` to the `Model`. This seems "nicer" overall, but will require quite a bit of work both here and on the Turing.jl side, plus it's not quite clear to me that this will indeed work (`context` is used in many other places than just `evaluate!!`, e.g. `unflatten`, to allow samplers in `SamplingContext` to change behaviors further).

## Addendum
This entire PR arose from the following scenario:

A `ConditionContext` is such that the "outermost" one takes precedence (since this is the one which was applied last), but in the above scenario this is not respected, since we end up using the `ConditionContext(x = 0, ...)` from `f.context` instead of the outermost one from `f.model.context`.
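The original snippet for the scenario didn't survive extraction, but based on the description above it was along these lines (hedged reconstruction, not the exact code from the PR; `demo` is a made-up model):

```julia
using DynamicPPL, Distributions

@model demo() = x ~ Normal()

# Build a `LogDensityFunction` from a model conditioned on x = 0.
f = LogDensityFunction(condition(demo(), x = 0))

# Later, re-condition the model on x = 1 and reuse the old `f.context`:
model_new = condition(demo(), x = 1)
f_new = LogDensityFunction(model_new, f.varinfo, f.context)

# `f_new` now carries two `ConditionContext`s for `x`: the outermost one
# (x = 1, in `f_new.model.context`) should take precedence, but evaluation
# ends up using x = 0 from `f_new.context` instead.
```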