cscherrer / Soss.jl

Probabilistic programming via source rewriting
https://cscherrer.github.io/Soss.jl/stable/
MIT License
414 stars, 30 forks

Update `testvalue` #259

Closed: cscherrer closed this issue 3 years ago

cscherrer commented 3 years ago

In @ptiede's issue https://github.com/cscherrer/Soss.jl/issues/258#issuecomment-819035325 we have

julia> m1
@model begin
        x1 ~ Soss.Normal(0.0, 1.0)
        x2 ~ Dists.MvNormal(fill(x1, 2), ones(2))
        return x2
    end

julia> testvalue(m1())
(x1 = 0.0, x2 = [0.0, 0.0])

But given the return x2, testvalue(m1()) should return only the value of x2 (a vector of floats), not the full named tuple. So if we now add

m2 = @model m begin
    μ ~ m
    y ~ For(μ) do x 
        Soss.Normal(x, 1.0)
    end
end

mm = m2(m=m1())

Then we get

julia> testvalue(mm)
ERROR: MethodError: no method matching basemeasure(::Nothing)

The problem is here:

julia> sourceBasemeasure(mm)
quote
    _bm = (;)
    _bm = merge(_bm, NamedTuple{(:μ,)}((basemeasure(m),)))
    μ = Soss.testvalue(m)
    _bm = merge(_bm, NamedTuple{(:y,)}((basemeasure(For(μ) do x
                            #= REPL[117]:4 =#
                            Soss.Normal(x, 1.0)
                        end),)))
    y = Soss.testvalue(For(μ) do x
                #= REPL[117]:4 =#
                Soss.Normal(x, 1.0)
            end)
    return Soss.ProductMeasure(_bm)
end

μ is set to the named tuple (x1 = 0.0, x2 = [0.0, 0.0]), for which For currently has no method.

I think the right way to fix this is to define testvalue on models the same way rand is defined: the model dispatches to testvalue on each component and returns only the value given by the model's return statement.
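
As a rough illustration of that idea, here is a toy sketch that does not use Soss at all. Every name in it (ToyNormal, ToyMvNormal, ToyModel, toy_testvalue) is hypothetical and made up for this example, and it is not the code from the eventual fix. Leaf measures return a fixed point in their support (zeros here, as an assumption), while a model runs its body with toy_testvalue applied to each component and hands back only what the body returns.

```julia
# Toy stand-ins, NOT Soss types: all names below are hypothetical.
struct ToyNormal
    μ::Float64
    σ::Float64
end

struct ToyMvNormal
    μ::Vector{Float64}
end

# A toy "model": `body` receives a `draw` function, applies it to each
# component in turn, and returns the model's return value.
struct ToyModel{F}
    body::F
end

# Leaf measures: return a fixed point in the support (zeros, by assumption).
toy_testvalue(::ToyNormal) = 0.0
toy_testvalue(d::ToyMvNormal) = zeros(length(d.μ))

# Models: dispatch to toy_testvalue on each component, and return only
# what the body returns, not a named tuple of every intermediate variable.
toy_testvalue(m::ToyModel) = m.body(toy_testvalue)

m1 = ToyModel() do draw
    x1 = draw(ToyNormal(0.0, 1.0))
    x2 = draw(ToyMvNormal(fill(x1, 2)))
    return x2
end

m2 = ToyModel() do draw
    μ = draw(m1)                               # a vector, since m1 returns x2
    y = [draw(ToyNormal(x, 1.0)) for x in μ]   # so iterating over μ works
    return (μ = μ, y = y)
end

toy_testvalue(m1)   # [0.0, 0.0]
toy_testvalue(m2)   # (μ = [0.0, 0.0], y = [0.0, 0.0])
```

Because the nested model yields only its return value, μ in the outer model is a plain vector, and the per-element construction over μ no longer sees a named tuple.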

cscherrer commented 3 years ago

Fixed in https://github.com/cscherrer/Soss.jl/commit/36527519bf5da31ee77c2cb37d7dbd9172a5f360

julia> testvalue(mm)
(μ = [-0.510301547701104, 1.615423098718431], y = [0.0, 0.0])

julia> testvalue(m1())
2-element Vector{Float64}:
 -0.5655353773129189
 -0.5534287288896844

Let's keep this open until we've added unit tests.