cscherrer / Soss.jl

Probabilistic programming via source rewriting
https://cscherrer.github.io/Soss.jl/stable/
MIT License

Int argument for predict function #301

Open paschermayr opened 2 years ago

paschermayr commented 2 years ago

From: https://julialang.zulipchat.com/#narrow/stream/240884-soss.2Ejl

Could the predict function take an additional argument specifying the number of samples to draw? This already works via predictive and rand:

using Soss, Distributions
dat = randn(100)
μ₀ = 1.
σ₀ = 2.

m = @model n begin
    μ ~ Distributions.Normal()
    σ ~ Distributions.Exponential()
    data ~ Distributions.Normal(μ, σ) |> iid(n)
    return (; data)
end
post = m((μ = μ₀, σ = σ₀, n = length(dat))) | (data = dat,)
vals = (μ = μ₀, σ = σ₀)

pred = predictive(m, keys(vals)...)
rand(pred(μ = μ₀, σ = σ₀, n = 1))

but not via the predict function itself:

Soss.predict(post, vals)        # returns a vector of length length(dat)
Soss.predict(post, vals, n = 1) # MethodError

cscherrer commented 2 years ago

Hi @paschermayr, sorry for the delay on this. I'm tied up for a while, but I think we can get something like this working. One possibility is to have a standard way to convert a predict call into a rand call, so that any extra arguments are simply passed along to rand. That way, any updates to rand would carry over to predict "for free".
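To make the forwarding idea concrete, here is a minimal plain-Julia sketch. It does not use Soss at all: ToyModel, toy_rand, and toy_predict are hypothetical names made up for illustration, and Soss's actual predict and rand methods are not modeled here. The only point is that toy_predict accepts arbitrary keyword arguments (kwargs...) and hands them straight to toy_rand, so the n keyword requested above works without toy_predict knowing about it.

```julia
using Random

# Hypothetical stand-in for a model: NOT Soss's API, just a toy
# holding the two parameters from the example above.
struct ToyModel
    μ::Float64
    σ::Float64
end

# Hypothetical `toy_rand`: draws `n` iid observations from Normal(μ, σ).
# The keyword `n` defaults to 1, so a bare call returns a single draw.
function toy_rand(rng::AbstractRNG, m::ToyModel; n::Integer = 1)
    return m.μ .+ m.σ .* randn(rng, n)
end

# Hypothetical `toy_predict`: plugs the parameter values in `vals` into the
# model, then delegates to `toy_rand`, forwarding every keyword unchanged.
function toy_predict(rng::AbstractRNG, m::ToyModel, vals::NamedTuple; kwargs...)
    updated = ToyModel(get(vals, :μ, m.μ), get(vals, :σ, m.σ))
    return toy_rand(rng, updated; kwargs...)
end

rng = MersenneTwister(1)
m = ToyModel(0.0, 1.0)
toy_predict(rng, m, (μ = 1.0, σ = 2.0))           # one draw (the default)
toy_predict(rng, m, (μ = 1.0, σ = 2.0); n = 100)  # 100 draws
```

Because toy_predict never names n itself, adding another keyword to toy_rand would need no change to toy_predict. The trade-off is that a misspelled keyword only raises an error once it reaches toy_rand.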

paschermayr commented 2 years ago

Thank you, that sounds good!