cscherrer / Soss.jl

Probabilistic programming via source rewriting
https://cscherrer.github.io/Soss.jl/stable/
MIT License
413 stars · 30 forks

New model transform that keeps all of the statements that its arg depends on #132

Closed · millerjoey closed this 4 years ago

millerjoey commented 4 years ago

Right now, prune(m, :a) (on the dev branch) returns a model with :a and every variable that depends on it removed. It might also be useful to define the transform

function prior(m::Model, xs...)
    po = poset(m) # Creates a new SimplePoset, so no need to copy before mutating

    newvars = collect(xs)

    # Everything below x in the poset is something x depends on
    for x in xs
        append!(newvars, below(po, x))
    end

    # Variables that were arguments of m stay arguments; the rest become statements
    newargs = arguments(m) ∩ newvars
    setdiff!(newvars, newargs)

    # Rebuild the model one statement at a time, starting from an empty model
    theModule = getmodule(m)
    m_init = Model(theModule, newargs, NamedTuple(), NamedTuple(), nothing)
    foldl(newvars; init=m_init) do m0, v
        merge(m0, Model(theModule, findStatement(m, v)))
    end
end

that returns the model as if you're designating :a (and everything it depends on) as a "prior". This should (I think) be exactly complementary to predictive, so that samples from prior(m, :a) can always be fed into predictive(m, :a).

Colliders make the difference between prune and prior easier to see:

julia> m
@model begin
        x ~ Normal()
        b ~ Normal(x)
        a ~ Normal(x)
        c ~ Normal(b + a)
    end

julia> prune(m, :a)
@model begin
        x ~ Normal()
        b ~ Normal(x)
    end

julia> prior(m, :a)
@model begin
        x ~ Normal()
        a ~ Normal(x)
    end
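For intuition, here is a minimal standalone sketch of the closure logic the two transforms compute, written against a plain dependency dict for the collider model above rather than Soss's actual poset machinery (the names deps, ancestors, prior_vars, prune_vars, and descendants are illustrative, not Soss API):

```julia
# Toy dependency graph for the collider model:
# x ~ Normal(); b ~ Normal(x); a ~ Normal(x); c ~ Normal(b + a)
deps = Dict(:x => Symbol[], :b => [:x], :a => [:x], :c => [:b, :a])

# Everything v depends on, transitively (what `below` gives in the poset)
function ancestors(deps, v)
    seen = Set{Symbol}()
    stack = copy(deps[v])
    while !isempty(stack)
        u = pop!(stack)
        u in seen && continue
        push!(seen, u)
        append!(stack, deps[u])
    end
    seen
end

# prior keeps xs together with everything they depend on
prior_vars(deps, xs...) = union(Set(collect(xs)), (ancestors(deps, x) for x in xs)...)

# prune drops xs and everything that depends on them
descendants(deps, v) = Set(k for k in keys(deps) if v in ancestors(deps, k))
prune_vars(deps, xs...) =
    setdiff(Set(keys(deps)), Set(collect(xs)), (descendants(deps, x) for x in xs)...)
```

For this graph, prior_vars(deps, :a) gives Set([:a, :x]) and prune_vars(deps, :a) gives Set([:x, :b]), matching the two model outputs above.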

Final thought: I'm wondering whether there's a small vocabulary of combinators that can express these kinds of transformations with minimal redundancy. Are there many useful things that prune, prior, and predictive can't do?

cscherrer commented 4 years ago

Thanks @millerjoey, this is really nice! I agree that minimizing redundancy is important, but I don't yet have a sense of what a "full" set of model combinators would look like. For example, IMO this one is very useful and important to have:

julia> [v => markovBlanket(m, v) for v in parameters(m)]
4-element Array{Pair{Symbol,B} where B,1}:
 :x => @model begin
        x ~ Normal()
        b ~ Normal(x)
        a ~ Normal(x)
    end

 :b => @model (x, a) begin
        b ~ Normal(x)
        c ~ Normal(b + a)
    end

 :a => @model (x, b) begin
        a ~ Normal(x)
        c ~ Normal(b + a)
    end

 :c => @model (a, b) begin
        c ~ Normal(b + a)
    end

This should make it relatively easy to do Gibbs sampling by alternating updates between these models. It (or something similar) should work for other message-passing algorithms as well. This could be written in terms of the PPP combinators, but I don't think it would be easy. I don't yet have a sense of how close we might be, or what other combinators we might soon need. I think the best we can do is keep an eye out for opportunities to decompose new operators we'd like into smaller pieces, and then hopefully the right pieces will start to fall out.
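To make the "alternating updates" idea concrete outside of Soss, here is a toy Gibbs sampler in plain Julia for a bivariate normal with correlation ρ, where each variable's Markov blanket is just the other variable. This is a hand-written sketch with closed-form conditionals; in the scheme above, each conditional draw would instead come from the corresponding markovBlanket model:

```julia
# Bivariate normal with correlation ρ: the full conditionals are
# x | y ~ Normal(ρy, 1 - ρ²) and y | x ~ Normal(ρx, 1 - ρ²).
const ρ = 0.8
cond_draw(other) = ρ * other + sqrt(1 - ρ^2) * randn()

function gibbs(nsteps)
    x, y = 0.0, 0.0
    samples = Vector{Tuple{Float64,Float64}}(undef, nsteps)
    for i in 1:nsteps
        x = cond_draw(y)   # update x given its Markov blanket {y}
        y = cond_draw(x)   # update y given its Markov blanket {x}
        samples[i] = (x, y)
    end
    samples
end
```

Each sweep updates every variable conditional on its blanket, which is exactly the alternation the decomposition above would mechanize.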

I'll try the prior you define above; hopefully we can get it added soon.

cscherrer commented 4 years ago

Added here: https://github.com/cscherrer/Soss.jl/commit/fc5fdc9a2ef61ecd414f263f1f6b553221d370fc

We can work from dev and make sure we're happy with it before merging to master. Thanks! :)