cscherrer / Tilde.jl

WIP successor to Soss.jl
MIT License

0.2 Major Refactoring #35

Closed: cscherrer closed this 2 years ago

cscherrer commented 2 years ago

FixedRNG

FixedRNG() is an AbstractRNG that is, well, fixed: each sampling function always returns the same value. Like this:

julia> rand(FixedRNG())
0.5

julia> randn(FixedRNG())
0.0

julia> randexp(FixedRNG())
1.0

One possibility is to move it to MeasureBase (the code is short) and use it to define a fallback method for testvalue.
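To make the idea concrete, here is a minimal sketch of how such an RNG could be written with the Random stdlib. It is not necessarily how Tilde defines FixedRNG. `ToyNormal` and the `testvalue` fallback are hypothetical stand-ins (not Distributions.jl or MeasureBase), included to show how a fixed RNG turns any sampler built on rand/randn/randexp into a deterministic test value.

```julia
using Random

# Minimal sketch: an RNG whose sampling functions each return a constant.
struct FixedRNG <: AbstractRNG end

# Constants matching the REPL output above: the midpoint of [0, 1) for rand,
# the mean of a standard normal for randn, and the mean of Exponential(1) for randexp.
Random.rand(::FixedRNG) = 0.5
Random.randn(::FixedRNG) = 0.0
Random.randexp(::FixedRNG) = 1.0

# Hypothetical toy distribution (not Distributions.jl's Normal) whose sampler
# is written in terms of randn.
struct ToyNormal
    μ::Float64
    σ::Float64
end
Random.rand(rng::AbstractRNG, d::ToyNormal) = d.μ + d.σ * randn(rng)

# Hypothetical testvalue fallback: sampling with a FixedRNG gives a
# deterministic, "central" value without a custom method per distribution.
testvalue(d) = rand(FixedRNG(), d)
```

With this, `testvalue(ToyNormal(2.0, 3.0))` returns `2.0`, since randn always gives `0.0`. The trick only covers samplers that go through these three primitives; a sampler using some other entry point would need further methods on FixedRNG.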

rand

Return-related functions

Configurations (<:AbstractConfig)

I've saved the best for last. There's a lot here, so let's walk through a simple example: computing a log-density with logdensityof.

First we define

struct LogdensityConfig{F} <: AbstractConfig
    f::F
end

AbstractConfig subtypes are used for dispatch in Tilde, and can also hold information to be referenced as the model is run.
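As a rough illustration of the dispatch pattern, here is a sketch with a local stand-in for AbstractConfig and a second, hypothetical config type `SampleConfig`; the function `describe` is also hypothetical. The point is only that each config selects its own method, and that the method can read data stored in the config, much as `LogdensityConfig` stores `f`.

```julia
# Stand-in for Tilde's AbstractConfig, so this sketch runs on its own.
abstract type AbstractConfig end

struct LogdensityConfig{F} <: AbstractConfig
    f::F
end

# Hypothetical second config type, only for illustration.
struct SampleConfig <: AbstractConfig
    ndraws::Int
end

# A hypothetical function with one method per config type. Each method can
# also read the information the config carries.
describe(cfg::LogdensityConfig) = "log-density via $(nameof(cfg.f))"
describe(cfg::SampleConfig) = "sampling $(cfg.ndraws) draws"
```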

Next, here's the top-level call:

function MeasureBase.logdensityof(
    cm::AbstractConditionalModel,
    pars::NamedTuple;
)
    cfg = LogdensityConfig(logdensityof)
    runmodel(cfg, cm, pars, (ℓ=0,))
end

AbstractConditionalModel can be either ModelClosure (a model with specified arguments) or ModelPosterior (if it has also been conditioned on some observations).

The logdensityof call is very simple: it just sets the config and then calls runmodel, which does the work. runmodel takes four arguments:

  1. The configuration
  2. An AbstractConditionalModel
  3. pars, which is just another way to specify values for the model's variables. This is more efficient than pulling the model apart and rebuilding it with new values each time, which matters because we expect pars to change often.
  4. A "context", which we usually shorten to ctx. This is typically a NamedTuple. Unlike a config, a context is expected to be updated at each ~ statement in the model.
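To show how the four arguments fit together, here is a toy runmodel. Everything in it is hypothetical: `ToyModel` stands in for a model closure, and `toybody` is hand-written to mimic a model body after Tilde's rewriting. Note that the config is only passed along, while the context is rebuilt at each ~ statement.

```julia
# Stand-ins so the sketch runs on its own; Tilde's real types are richer.
abstract type AbstractConfig end
struct LogdensityConfig{F} <: AbstractConfig
    f::F
end

# Hypothetical toy "model closure": `args` are the model's arguments and `body`
# stands in for the rewritten model body, a function (cfg, args, pars, ctx) -> ctx.
struct ToyModel{A,B}
    args::A
    body::B
end

# Toy runmodel: the config is passed along unchanged, while the context is
# threaded through the body and updated at each ~ statement.
runmodel(cfg::AbstractConfig, m::ToyModel, pars::NamedTuple, ctx::NamedTuple) =
    m.body(cfg, m.args, pars, ctx)

# Body for a two-statement model, a ~ Normal(μ, 1) and b ~ Normal(a, 1),
# accumulating the log-density (up to an additive constant) in ctx.ℓ.
function toybody(cfg, args, pars, ctx)
    a, b = pars.a, pars.b
    ctx = merge(ctx, (ℓ = ctx.ℓ - (a - args.μ)^2 / 2,))  # a ~ Normal(μ, 1)
    ctx = merge(ctx, (ℓ = ctx.ℓ - (b - a)^2 / 2,))       # b ~ Normal(a, 1)
    return ctx
end

m = ToyModel((μ = 0.0,), toybody)
```

For example, `runmodel(LogdensityConfig(identity), m, (a = 1.0, b = 3.0), (ℓ = 0.0,))` returns `(ℓ = -2.5,)`. Evaluating at new pars reuses `m` as-is, which is why pars is kept separate from the model.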

Now, if the model happens to have a return statement, we need to be sure not to execute it as written. To make this more general, we replace every return r with return retfun(cfg, r, ctx), so that the config determines what is actually returned. If the model has no return statement, it implicitly returns a named tuple of all variables appearing on the left of a ~.
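Here is a sketch of how retfun might dispatch on the config. The `LogdensityConfig` method, which discards the user's value and returns `ctx.ℓ`, is an assumption made for illustration; `PlainConfig` and `toybody` are hypothetical.

```julia
abstract type AbstractConfig end
struct LogdensityConfig{F} <: AbstractConfig
    f::F
end
struct PlainConfig <: AbstractConfig end  # hypothetical "just run the model" config

# Hypothetical fallback: return the model's own return value unchanged.
retfun(cfg::AbstractConfig, r, ctx) = r
# Assumed behavior for log-densities: ignore r and return the accumulated ℓ.
retfun(cfg::LogdensityConfig, r, ctx) = ctx.ℓ

# A toy model body after the rewrite: the user's `return r`
# has become `return retfun(cfg, r, ctx)`.
function toybody(cfg, x, ctx)
    ctx = merge(ctx, (ℓ = ctx.ℓ - x^2 / 2,))  # x ~ Normal(), up to a constant
    r = 2x                                    # the user's original return value
    return retfun(cfg, r, ctx)
end
```

So `toybody(PlainConfig(), 3.0, (ℓ = 0.0,))` returns the user's value `6.0`, while `toybody(LogdensityConfig(identity), 3.0, (ℓ = 0.0,))` returns `-4.5`.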

Say our model has an observed value x, and contains the statement

x[j] ~ Normal()

Tilde first rewrites this (using Tilde.opticize) as

(x, @optic _[j]) ~ Normal()
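The rewrite splits `x[j]` into the whole variable `x` and an optic describing which part of it is on the left of the `~`. `@optic` comes from Accessors.jl, which also provides `set` for non-mutating updates; the sketch below is a hypothetical hand-rolled version, `IndexOptic` with `set_optic`, just to show the two operations an optic supports.

```julia
# Hypothetical minimal index optic, standing in for what Accessors.jl's
# `@optic _[j]` constructs.
struct IndexOptic{I<:Tuple}
    indices::I
end

# Getting: applying the optic to a container focuses on one element.
(o::IndexOptic)(x) = x[o.indices...]

# Setting without mutation: return a copy with the focused element replaced.
function set_optic(x::AbstractArray, o::IndexOptic, v)
    y = copy(x)
    y[o.indices...] = v
    return y
end
```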

Next, we often want different behavior depending on whether a value has been observed. For this we have an abstract type MaybeObserved, with subtypes Observed and Unobserved.

Since x is observed, we represent it as Observed{:x}(x). This is a handy way to pass three pieces of information:

  1. This variable has been observed
  2. Its name is :x
  3. Its value is x
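A sketch of such types might look like this. The field and type-parameter layout, and the helpers `varname` and `isobserved`, are assumptions; only `MaybeObserved`, `Observed`, `Unobserved`, and `value` appear in the description here. Putting the name in a type parameter means it is known at compile time, so methods can dispatch on it.

```julia
# Hypothetical definitions mirroring the names used above; Tilde's actual
# definitions may differ.
abstract type MaybeObserved{N} end

struct Observed{N,T} <: MaybeObserved{N}
    value::T
end
struct Unobserved{N,T} <: MaybeObserved{N}
    value::T
end

# Allow `Observed{:x}(x)`, inferring the value type T from the argument.
Observed{N}(v::T) where {N,T} = Observed{N,T}(v)
Unobserved{N}(v::T) where {N,T} = Unobserved{N,T}(v)

value(o::MaybeObserved) = o.value
varname(::MaybeObserved{N}) where {N} = N  # the name, recovered from the type

isobserved(::Observed) = true
isobserved(::Unobserved) = false
```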

Finally, the line x[j] ~ Normal() is converted to

(x, ctx) = tilde(cfg, Observed{:x}(x), (@optic _[j]), Normal(), ctx)

The tilde function needs to dispatch on the type of cfg (and possibly other arguments) and return an updated x and context.

In this case, that function looks like this:

@inline function tilde(
    cfg::LogdensityConfig{typeof(logdensityof)},
    obj::MaybeObserved{X},
    lens,
    d,
    ctx::NamedTuple,
) where {X}
    x = value(obj)
    pred = predict(d, lens(x))
    @reset ctx.ℓ += logdensityof(d, lens(x))
    (pred, ctx)
end

That pred line is a little weird. d might itself be a model, in which case our observations would be of that model's latent variables. For the case where d is not a model, there's a fallback method, predict(m, x) = x.
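Putting the pieces together, here is a self-contained sketch of the log-density `tilde` method and the `predict` fallback. To avoid depending on MeasureBase and Accessors, it uses a hypothetical `StdNormal` with `toylogdensity` in place of `logdensityof`, `merge` in place of `@reset`, and an anonymous function in place of `@optic _[j]`. `loglik` is a hypothetical driver standing in for a rewritten model with `x[j] ~ StdNormal()` in a loop.

```julia
# Stand-ins so the sketch runs on its own.
abstract type AbstractConfig end
struct LogdensityConfig{F} <: AbstractConfig
    f::F
end
abstract type MaybeObserved{N} end
struct Observed{N,T} <: MaybeObserved{N}
    value::T
end
Observed{N}(v::T) where {N,T} = Observed{N,T}(v)
value(o::MaybeObserved) = o.value

# Hypothetical toy distribution and log-density, in place of logdensityof.
struct StdNormal end
toylogdensity(::StdNormal, x) = -x^2 / 2  # up to an additive constant

# Fallback: when d is not a model, the "prediction" is the value itself.
predict(d, x) = x

# Same shape as the tilde method above, using `merge` instead of `@reset`.
function tilde(
    cfg::LogdensityConfig{typeof(toylogdensity)},
    obj::MaybeObserved,
    lens,
    d,
    ctx::NamedTuple,
)
    x = value(obj)
    pred = predict(d, lens(x))
    ctx = merge(ctx, (ℓ = ctx.ℓ + toylogdensity(d, lens(x)),))
    return (pred, ctx)
end

# Hypothetical driver: what `x[j] ~ StdNormal()` inside a loop over j might
# reduce to, with an anonymous function standing in for `@optic _[j]`.
function loglik(x)
    cfg = LogdensityConfig(toylogdensity)
    obs = Observed{:x}(x)
    ctx = (ℓ = 0.0,)
    for j in eachindex(x)
        _, ctx = tilde(cfg, obs, v -> v[j], StdNormal(), ctx)
    end
    return ctx.ℓ
end
```

Here `loglik([0.0, 1.0, 2.0])` returns `-2.5`, the sum of `-x[j]^2 / 2` over the three entries.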

codecov[bot] commented 2 years ago

Codecov Report

No coverage uploaded for pull request base (main@49c3ddc). The diff coverage is n/a.

Current head edbfc10 differs from pull request most recent head 6655917. Consider uploading reports for the commit 6655917 to get more accurate results.

@@           Coverage Diff           @@
##             main      #35   +/-   ##
=======================================
  Coverage        ?   42.36%           
=======================================
  Files           ?       45           
  Lines           ?     1435           
  Branches        ?        0           
=======================================
  Hits            ?      608           
  Misses          ?      827           
  Partials        ?        0           
