TuringLang / DynamicPPL.jl

Implementation of domain-specific language (DSL) for dynamic probabilistic programming
https://turinglang.org/DynamicPPL.jl/
MIT License

Fix for `LogDensityFunction` #621

Closed Β· torfjelde closed this 2 months ago

torfjelde commented 2 months ago

Issue

When evaluating a Model, there are two "sources" of contexts: 1) the context explicitly passed to evaluate!! as an argument, and 2) the context attached to the model itself in model.context.

The latter was introduced because in many scenarios it makes sense to "contextualize" a Model, e.g. attach a ConditionContext to a Model to specify which parameters are considered conditioned. The former is a holdover from the days when the samplers were heavily tied to DynamicPPL and we passed a sampler argument in almost every place where we now pass a context.

To "bridge" the two approaches, when we call evaluate!!, the process of "resolving" the context that eventually ends up as __context__ in the model itself, occurs here:

https://github.com/TuringLang/DynamicPPL.jl/blob/d384da21e168ca089f16551231913700a4454efc/src/model.jl#L995-L997

In short, we do the following (a rough sketch in code follows the list):

  1. Replace the leaf context of model.context with the leaf context provided by context argument.
  2. Set the child context of context argument to the context resulting from (1).
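In code, the two steps above look roughly like the sketch below, written with DynamicPPL's context accessors (`leafcontext`, `setleafcontext`, `setchildcontext`). The name `resolve_contexts` is just for illustration; the linked lines above contain the actual implementation.

```julia
# Rough sketch of the resolution described above (see the linked source for the real code).
function resolve_contexts(model, context)
    # (1) Replace the leaf of `model.context` with the leaf of the `context` argument.
    inner = setleafcontext(model.context, leafcontext(context))
    # (2) Set the child of the `context` argument to the result of (1).
    return setchildcontext(context, inner)
end
```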

This might be a bit strange, but the result is that context takes precedence over model.context, as it's considered to be "more important" due to the "user" explicitly passing it to evaluate!!.

We did this because some samplers were using contexts to specify certain behaviors that had to be respected, e.g. context could be a PriorContext to indicate that the prior should be evaluated while model.context could be a DefaultContext, in which case we wanted the result to be PriorContext.

This also means that LogDensityFunction, effectively a convenient wrapper around evaluate!!, also has two sources of contexts: f.model.context and f.context. By default, i.e. if we call LogDensityFunction(model), these are the same, i.e. f.context === f.model.context. This is clearly redundant, since we're just specifying the same context twice. Moreover, since, as seen above, we effectively concatenate context and model.context, LogDensityFunction ends up evaluating the model with the context "doubled". In most cases this still produces the intended behavior, but once you start changing certain fields of the LogDensityFunction, e.g. LogDensityFunction(model_new, f.varinfo, f.context), interesting things can happen. For example, in https://github.com/TuringLang/Turing.jl/pull/2231, I ran into an issue where I'd get two ConditionContexts conditioning the same variable: one from f.model.context (what I intended) and one from f.context (what I did not intend).
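To make the "doubling" concrete, here is a small hypothetical sketch of what happens under the resolution described above (illustrative only):

```julia
model = condition(model, x=0)   # `model.context` is `ConditionContext(x = 0, DefaultContext())`
f = LogDensityFunction(model)   # by default `f.context === f.model.context`
# When the model is evaluated, `f.context` (passed as the `context` argument) and
# `f.model.context` are concatenated as described above, effectively giving
# `ConditionContext(x = 0, ConditionContext(x = 0, DefaultContext()))`,
# i.e. the same condition applied twice.
```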

Solution

This PR addresses this issue for LogDensityFunction by simply allowing nothing in f.context, which is resolved to leafcontext(model.context) if not specified (a rough sketch follows the list below). This addresses the issues I've encountered above, but the proper way of fixing this would, IMO, be to either:

  1. Allow nothing to be passed in place of context "everywhere", i.e. we make them all Optional{AbstractContext} = Union{Nothing,AbstractContext} types, and resolve to model.context whenever it's nothing.
  2. Drop the context argument from evaluate!! completely and instead always just "attach" the context to the Model. This seems "nicer" overall, but will require quite a bit of work both here and on the Turing.jl side + it's not quite clear to me that this will indeed work (context is used in many other places than just evaluate!!, e.g. unflatten, to allow samplers in SamplingContext to change behaviors further).
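For reference, the behavior this PR gives LogDensityFunction can be sketched roughly as follows (illustrative only, not the exact code in the PR):

```julia
# Rough sketch of the idea behind this PR (not the exact implementation):
# `f.context === nothing` means "no context explicitly requested", in which case
# we fall back to the leaf of the model's own context; otherwise use `f.context` as given.
function getcontext(f::LogDensityFunction)
    return f.context === nothing ? leafcontext(f.model.context) : f.context
end
```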

Addendum

This entire PR arose from the following scenario:

model = condition(model, x=0) # `model.context` is now a `ConditionContext(x = 0, DefaultContext())`.
f = LogDensityFunction(model)  # Both `f.model.context` and `f.context` are now `ConditionContext(x = 0, DefaultContext())`
f = Accessors.@set f.model = condition(f.model, x=1)  # `f.model.context` is now `ConditionContext(x = 1, ConditionContext(x = 0, DefaultContext()))`
LogDensityProblems.logdensity(f, params) # Here we get `f.context` wrapping `f.model.context`, i.e. `ConditionContext(x = 0, ConditionContext(x = 1, ...))`

A ConditionContext is such that the "outermost" one takes precedence (since this is the one which was applied last), but in the above scenario this is not respected since we end up using the ConditionContext(x = 0, ...) from f.context instead of the outermost one from f.model.context.

yebai commented 2 months ago

It looks like a good one for @mhauru to review.

mhauru commented 2 months ago

I was just today reading through DynamicPPL and noticed that model.context is a thing, and had to take a while to figure out why, given the explicit passing around of a separate context.

Do I understand correctly that functionally the changes here are equivalent to changing

context::AbstractContext=model.context

to

context::AbstractContext=leafcontext(model.context)

and the rest, introducing the nothing and the getcontext function, are an aesthetic preference for not having the LogDensityFunction struct store redundant data?

yebai commented 2 months ago

I think we should unify these contexts eventually, although not necessarily in this PR.

I lean towards contextualising a model before passing it to evaluate!!:


# check for invalid context composition; note that `contextualising!!` could be called more than once
model_with_context  = contextualising!!(model, context) 

res = evaluate!!(rng, model_with_context, ...) # remove context argument here

If a model is already conditioned, contextualising it again can throw an error in cases where the context composition is invalid.

This is probably the same as @torfjelde's idea above, removing the context argument from evaluate!! completely but introducing an explicit contextualising!! function.

torfjelde commented 2 months ago

> and the rest, introducing the nothing and the getcontext function, are an aesthetic preference for not having the LogDensityFunction struct store redundant data?

It's not so much about "not storing redundant data", but rather about "lazily" resolving the context in the case of f.context === nothing; passing leafcontext(model.context) is effectively a no-op.

> This is probably the same as @torfjelde's idea above, removing the context argument from evaluate!! completely but introducing an explicit contextualising!! function.

We already have this: contextualize(model, context). But otherwise, yes we're on the same page: replace all calls to evaluate!!(model, varinfo, context) with

evaluate!!(contextualize(model, context), varinfo)
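For reference, contextualize essentially just swaps out the context stored in the model. A minimal sketch of the idea, using Accessors as elsewhere in this thread (the actual definition lives in DynamicPPL; the name my_contextualize is only for illustration):

```julia
using Accessors

# Minimal sketch: "attaching" a context to a model just replaces `model.context`.
my_contextualize(model, context) = Accessors.@set model.context = context
```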

But, as I said above, this isn't so easy because we use explicit context-passing in quite a few places beyond evaluate!! πŸ˜•

yebai commented 2 months ago

> But, as I said above, this isn't so easy because we use explicit context-passing in quite a few places beyond evaluate!! πŸ˜•

Is there any other difficulty other than finding all the places and then updating them?

torfjelde commented 2 months ago

> Is there any other difficulty other than finding all the places and then updating them?

It's a question of what you do with methods such as getindex(varinfo, context) (which are there because we need to let samplers pick out slices of varinfo, as we haven't yet replaced the Gibbs sampler). Once we have fully replaced the Gibbs sampler, we'll be able to drop a lot of these methods that use context / sampler, and then it will indeed be "just finding all the places and replacing them".

torfjelde commented 2 months ago

Something weird is happening with Documenter.jl here? Seems like everything is missing some whitespace o.O

coveralls commented 2 months ago

Pull Request Test Coverage Report for Build 9560893923

Details


| Changes Missing Coverage | Covered Lines | Changed/Added Lines | % |
| --- | --- | --- | --- |
| src/logdensityfunction.jl | 7 | 8 | 87.5% |
Totals Coverage Status
Change from base Build 9495604846: -0.01%
Covered Lines: 2759
Relevant Lines: 3434

πŸ’› - Coveralls
torfjelde commented 2 months ago

This should be ready

mhauru commented 2 months ago

The integration test fails because Turing.jl has code that tries to access the LogDensityFunction.context field directly and is not prepared to handle it being nothing. This may be a more prevalent issue, i.e. this PR may be breaking for many dependants.

One solution to this would be to go and fix all the dependants to use getcontext(x) rather than x.context. This would be nicer in the long term, IMO, but would require making edits to all those packages, and might cause breakage in the meantime if dependants don't specify version upper bounds for DynamicPPL. It would also mean that those packages would no longer work with older versions of DynamicPPL.

Another option would be to override getproperty for LogDensityFunction, so that x.context would actually return getcontext(x). This may be a bit overkill in terms of using a somewhat "deep" Julia construct (getproperty) to solve quite a simple problem, but it would be minimally disruptive.
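Such an override could look roughly like the following sketch (illustrative only; it inlines the nothing-resolution from this PR rather than calling getcontext, to avoid recursive field access):

```julia
# Rough sketch of the `getproperty` override discussed above (illustrative only).
function Base.getproperty(f::LogDensityFunction, name::Symbol)
    if name === :context
        # Resolve the stored context lazily: `nothing` falls back to the
        # leaf of the model's own context, as elsewhere in this PR.
        context = getfield(f, :context)
        return context === nothing ? leafcontext(getfield(f, :model).context) : context
    end
    return getfield(f, name)
end
```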

Relevant Julia style guide page: https://docs.julialang.org/en/v1/manual/style-guide/#Prefer-exported-methods-over-direct-field-access

I lean towards the former solution. Other thoughts?

yebai commented 2 months ago

> One solution to this would be to go and fix all the dependants to use getcontext(x) rather than x.context. This would be nicer in the long term, IMO, but would require making edits to all those packages, and might cause breakage in the meantime if dependants don't specify version upper bounds for DynamicPPL.

@mhauru, can you create a PR for Turing that adopts the suggestion you propose above? For packages without DynamicPPL bounds, that's unfortunate; maybe this is the opportunity to add such bounds. However, I am not aware of any package depending on DynamicPPL without an explicit version bound.

Also, does that mean this PR can be merged as a breaking release?

mhauru commented 2 months ago

Yep, we can make this a breaking release and be fine. I'll make the Turing.jl PR tomorrow.

github-actions[bot] commented 2 months ago

Pull Request Test Coverage Report for Build 9666211295

Details


| Changes Missing Coverage | Covered Lines | Changed/Added Lines | % |
| --- | --- | --- | --- |
| src/logdensityfunction.jl | 6 | 8 | 75.0% |

| Files with Coverage Reduction | New Missed Lines | % |
| --- | --- | --- |
| src/logdensityfunction.jl | 1 | 61.9% |
| src/sampler.jl | 1 | 94.12% |
| ext/DynamicPPLForwardDiffExt.jl | 1 | 77.78% |
| src/contexts.jl | 2 | 77.27% |
| src/threadsafe.jl | 4 | 50.0% |
| src/abstract_varinfo.jl | 5 | 82.68% |
| src/context_implementations.jl | 8 | 58.63% |
| src/model_utils.jl | 10 | 19.64% |
| src/loglikelihoods.jl | 16 | 54.84% |
| src/varinfo.jl | 55 | 85.36% |
Totals Coverage Status
Change from base Build 9495604846: -3.7%
Covered Lines: 2642
Relevant Lines: 3448

πŸ’› - Coveralls
coveralls commented 2 months ago

Pull Request Test Coverage Report for Build 9666211295

Details


| Changes Missing Coverage | Covered Lines | Changed/Added Lines | % |
| --- | --- | --- | --- |
| src/logdensityfunction.jl | 6 | 8 | 75.0% |

| Files with Coverage Reduction | New Missed Lines | % |
| --- | --- | --- |
| src/logdensityfunction.jl | 1 | 61.9% |
| src/sampler.jl | 1 | 94.12% |
| ext/DynamicPPLForwardDiffExt.jl | 1 | 77.78% |
| src/contexts.jl | 2 | 77.27% |
| src/threadsafe.jl | 4 | 50.0% |
| src/abstract_varinfo.jl | 5 | 82.68% |
| src/context_implementations.jl | 8 | 58.63% |
| src/model_utils.jl | 10 | 19.64% |
| src/loglikelihoods.jl | 16 | 54.84% |
| src/varinfo.jl | 55 | 85.36% |
Totals Coverage Status
Change from base Build 9495604846: -2.8%
Covered Lines: 2658
Relevant Lines: 3427

πŸ’› - Coveralls
torfjelde commented 2 months ago

Thanks for getting this through!

> Another option would be to override getproperty for LogDensityFunction, so that x.context would actually return getcontext(x). This may be a bit overkill in terms of using a somewhat "deep" Julia construct (getproperty) to solve quite a simple problem, but it would be minimally disruptive.

Agree that this would have been overkill :)