cscherrer / Soss.jl

Probabilistic programming via source rewriting
https://cscherrer.github.io/Soss.jl/stable/
MIT License
414 stars 30 forks source link

Reducing redundancy for primitive functions #250

Open cscherrer opened 3 years ago

cscherrer commented 3 years ago

"Primitive" here is a term I've been using for functions that use GeneralizedGenerated.jl to generate a function based on a Model and usually some other values. For each of these, there's a source____ function that builds the AST, for example sourceLogdensity and sourceRand.

For example, logdensity is built from

"""
    sourceLogdensity()

Return a function that takes a `Model` and builds the AST used for `logdensity`.

Each model statement is translated by a local `translate` method:

- `Sample`: accumulate `logdensity(rhs, x)` into `_ℓ`, then rebind `x` to
  `Soss.predict(rhs, x)`.
- `Assign`: emit the plain assignment `x = rhs`.
- `Return` and `LineNumber`: emit nothing.

The local `finish` function surrounds the expression it is given with the
initialisation `_ℓ = 0.0` and a final `return _ℓ`. The model and both functions
are handed to `buildSource`, and the result is passed through `MacroTools.flatten`.
"""
function sourceLogdensity()
    function (_m::Model)
        # Statement-level translation, one method per statement type.
        function translate(_m, st::Sample)
            lhs = st.x
            dist = st.rhs
            return @q begin
                _ℓ += logdensity($dist, $lhs)
                $lhs = Soss.predict($dist, $lhs)
            end
        end
        translate(_m, st::Assign) = :($(st.x) = $(st.rhs))
        translate(_m, st::Return) = nothing
        translate(_m, st::LineNumber) = nothing

        # Surround the translated body with the accumulator setup and the return.
        function finish(body)
            return @q begin
                _ℓ = 0.0
                $body
                return _ℓ
            end
        end

        return MacroTools.flatten(buildSource(_m, translate, finish))
    end
end

and rand is built from

"""
    sourceRand()

Return a function that takes a `Model` and builds the AST used for `rand`.

Each model statement is translated by a local `emit` method:

- `Sample`: emit `x = rand(_rng, rhs)`.
- `Assign`: emit the plain assignment `x = rhs`.
- `Return`: emit `return rhs`.
- `LineNumber`: emit nothing.

The local `finish` function places the expression it is given inside an anonymous
function of `_rng` whose last expression is a tuple of `name = name` entries, one
for each element of `parameters(_m)`. The model and both functions are handed to
`buildSource`, and the result is passed through `MacroTools.flatten`.
"""
function sourceRand()
    function (_m::Model)
        # Statement-level translation, one method per statement type.
        emit(_m, st::Sample) = :($(st.x) = rand(_rng, $(st.rhs)))
        emit(_m, st::Assign) = :($(st.x) = $(st.rhs))
        emit(_m, st::Return) = :(return $(st.rhs))
        emit(_m, st::LineNumber) = nothing

        # One `name = name` expression per model parameter, for the trailing tuple.
        fields = map(parameters(_m)) do p
            Expr(:(=), p, p)
        end

        # Wrap the translated body in `_rng -> begin ... end`, ending with the tuple.
        function finish(body)
            return @q begin
                _rng -> begin
                    $body
                    $(Expr(:tuple, fields...))
                end
            end
        end

        return MacroTools.flatten(buildSource(_m, emit, finish))
    end
end

There's clearly a lot of commonality between these, and also between the many calls to @gg:

chad@albatross ~/g/Soss.jl (dev)> rg @gg
src/importance.jl
167:@gg M function _importanceSample(_::Type{M}, p::Model, _pargs, q::Model, _qargs, _data) where M <: TypeLevel{Module}

src/simulate.jl
124:@gg M function _simulate(_::Type{M}, _m::Model, _args, trace_assignments::Val{V}) where {V, M <: TypeLevel{Module}}
131:@gg M function _simulate(_::Type{M}, _m::Model, _args::NamedTuple{()}, trace_assignments::Val{V}) where {V, M <: TypeLevel{Module}}

src/particles.jl
150:@gg M function _particles(_::Type{M}, _m::Model, _args, _n::Val{_N}) where {M <: TypeLevel{Module},_N}
156:@gg M function _particles(_::Type{M}, _m::Model, _args::NamedTuple{()}, _n::Val{_N}) where {M <: TypeLevel{Module},_N}

src/primitives/likelihood-weighting.jl
38:@gg M function _weightedSample(_::Type{M}, _m::Model, _args, _data) where M <: TypeLevel{Module}

src/primitives/rand.jl
59:@gg M function _rand(_::Type{M}, _m::Model, _args) where M <: TypeLevel{Module}
65:@gg M function _rand(_::Type{M}, _m::Model, _args::NamedTuple{()}) where M <: TypeLevel{Module}

src/primitives/logdensity.jl
43:@gg M function _logdensity(_::Type{M}, _m::Model, _args, _data, _pars) where M <: TypeLevel{Module}

src/primitives/xform.jl
148:@gg M function _xform(_::Type{M}, _m::Model{Asub,B}, _args::A, _data) where {M <: TypeLevel{Module}, Asub, A,B}

src/primitives/entropy.jl
55:@gg M function _entropy(_::Type{M}, _m::Model, _args, _n::Val{_N}) where {M <: TypeLevel{Module},_N}
61:@gg M function _entropy(_::Type{M}, _m::Model, _args::NamedTuple{()}, _n::Val{_N}) where {M <: TypeLevel{Module},_N}

src/symbolic/symbolic.jl
143:@gg M function _symlogdensity(_::Type{M}, _m::Model, ::Type{T}) where {T, M <: TypeLevel{Module}}

src/primitives/basemeasure.jl
40:@gg M function _basemeasure(_::Type{M}, _m::Model, _args, _data, _pars) where M <: TypeLevel{Module}
chad@albatross ~/g/Soss.jl (dev)> 

This makes me wonder, can we put all of this under a common higher-order function? Maybe something like

@gg M function makeprimitive(::Type{M}, _m::Model, f, post, args...)

where f takes the place of proc (since that name's not so descriptive anyway), and args... can hold whatever other arguments are passed. post is a function Expr -> Expr, which in many cases might just add some surrounding context.

Some challenges:

If it can become easier to build new primitives, this will encourage people to use this functionality. I think there's a really great potential if we can do this. Things do get tricky at this degree of abstraction, so we need to be sure we can completely represent what we have already without losing performance.