cscherrer / Soss.jl

Probabilistic programming via source rewriting
https://cscherrer.github.io/Soss.jl/stable/
MIT License

AST Models #285

Open cscherrer opened 3 years ago

cscherrer commented 3 years ago

WIP to support "AST models". A main idea here is that the user can pass a "tilde function", and the model body is rewritten so that each statement

x ~ rhs

becomes

(x, _ctx) = tilde(Val(:x), rhs, _ctx)
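The rewrite itself is ordinary Julia metaprogramming. As a minimal sketch that assumes only the surface syntax shown above (it is not Soss's actual implementation), a hypothetical rewrite_tildes helper could walk the expression tree and replace each tilde call whose left-hand side is a plain symbol:

```julia
# Hypothetical helper, not Soss's implementation: replace every `x ~ rhs`
# (with a plain symbol on the left) by `(x, _ctx) = tilde(Val(:x), rhs, _ctx)`.
function rewrite_tildes(ex)
    if ex isa Expr && ex.head == :call && length(ex.args) == 3 &&
       ex.args[1] == :~ && ex.args[2] isa Symbol
        x, rhs = ex.args[2], ex.args[3]
        # QuoteNode(x) makes the symbol appear as the literal `:x` inside Val(...)
        return :(($x, _ctx) = tilde(Val($(QuoteNode(x))), $rhs, _ctx))
    elseif ex isa Expr
        # Recurse into blocks, calls, etc., rebuilding each node
        return Expr(ex.head, map(rewrite_tildes, ex.args)...)
    else
        # Symbols, literals and LineNumberNodes are left unchanged
        return ex
    end
end

body = quote
    p ~ Uniform()
    x ~ Bernoulli(p)
end

# Returns the block with both tilde statements rewritten
rewrite_tildes(body)
```

Since tilde and _ctx are just names in the generated code, whatever function is bound to tilde when the body runs decides what a "sample" statement means.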

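Rewriting each tilde statement this way makes it easy to thread a context (_ctx) through a model: every tilde call receives the current context and returns the updated one along with the value for x.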
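To make the context threading concrete, here is a self-contained sketch that runs without Soss. The names MyUniform, MyBernoulli, sample_tilde and model_body are hypothetical stand-ins, and the context is assumed to be a NamedTuple holding an RNG and a trace of drawn values, which is only one possible choice. model_body is written by hand to look like the rewritten body of a model with p ~ Uniform() and x ~ Bernoulli(p):

```julia
using Random

# Hypothetical stand-ins for Uniform() and Bernoulli(p),
# so this sketch runs without Soss or MeasureTheory.
struct MyUniform end
struct MyBernoulli
    p::Float64
end
Random.rand(rng::AbstractRNG, ::MyUniform) = rand(rng)
Random.rand(rng::AbstractRNG, d::MyBernoulli) = rand(rng) < d.p

# A hypothetical tilde function. The context holds an RNG and a `trace`
# of values drawn so far; each call draws from `rhs`, records the value
# under the variable's name, and returns the value plus the new context.
function sample_tilde(::Val{name}, rhs, ctx) where {name}
    x = rand(ctx.rng, rhs)
    trace = merge(ctx.trace, NamedTuple{(name,)}((x,)))
    return (x, merge(ctx, (trace = trace,)))
end

# Hand-written version of a rewritten model body. The tilde function
# is an argument, so it can be swapped without touching the body.
function model_body(tilde, _ctx)
    (p, _ctx) = tilde(Val(:p), MyUniform(), _ctx)
    (x, _ctx) = tilde(Val(:x), MyBernoulli(p), _ctx)
    return _ctx
end

ctx = model_body(sample_tilde, (rng = MersenneTwister(1), trace = NamedTuple()))
ctx.trace  # a NamedTuple with fields p and x
```

Because the body only ever passes _ctx along, a different tilde function (for example one that accumulates a log-density instead of drawing) could reuse the same body unchanged.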

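Second, it allows for call overloading: every function call in the model body can be routed through a user-supplied function passed as the call keyword. For example, say we have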

using Soss, MeasureTheory, Random
m = @model begin
    p ~ Uniform()
    x ~ Bernoulli(p)
end

Then we could define

function mycall(f, args...; kwargs...)
    @show f                  # print the function about to be called
    f(args...; kwargs...)    # then call it with the original arguments
end

Passing it to rand as the call keyword, we get

julia> rand(Random.GLOBAL_RNG, m; call=mycall)
f = Val{:p}
f = Uniform
f = Soss.var"#tilde#46"()
f = Val{:x}
f = Bernoulli
f = Soss.var"#tilde#46"()
(p = 0.7428791671316068, x = true)
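The same mechanism can be mimicked without Soss. In the sketch below, recording_call, simple_tilde, model_body, MyUniform and MyBernoulli are hypothetical illustrative names, not Soss API. It hand-writes a body in which every application f(args...) has become call(f, args...), including the Val constructor and the tilde call itself, and it records each f instead of printing it. Because arguments are evaluated before the enclosing call, the recorded order follows the same pattern as the printed output above: Val{:p}, the distribution constructor, then the tilde function.

```julia
using Random

# Hypothetical stand-ins for Uniform() and Bernoulli(p).
struct MyUniform end
struct MyBernoulli
    p::Float64
end
Random.rand(rng::AbstractRNG, ::MyUniform) = rand(rng)
Random.rand(rng::AbstractRNG, d::MyBernoulli) = rand(rng) < d.p

# A minimal tilde function whose context is just the RNG.
simple_tilde(::Val, rhs, rng) = (rand(rng, rhs), rng)

# Hand-written sketch of a rewritten body where every function
# application `f(args...)` has become `call(f, args...)`.
function model_body(call, tilde, _ctx)
    (p, _ctx) = call(tilde, call(Val{:p}), call(MyUniform), _ctx)
    (x, _ctx) = call(tilde, call(Val{:x}), call(MyBernoulli, p), _ctx)
    return (p = p, x = x)
end

# Like mycall, but records each function instead of printing it.
seen = Any[]
function recording_call(f, args...; kwargs...)
    push!(seen, f)
    f(args...; kwargs...)
end

result = model_body(recording_call, simple_tilde, MersenneTwister(2))
```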