cscherrer / Soss.jl

Probabilistic programming via source rewriting
https://cscherrer.github.io/Soss.jl/stable/
MIT License

Model composition #258

Closed: ptiede closed this issue 3 years ago

ptiede commented 3 years ago

Hi, I am trying to use Soss's model composition feature (version 0.18.1) and am running into an issue. Here is my MWE:

using Soss
using Soss: Dists

m1 = @model begin
    x1 ~ Soss.Normal(0.0, 1.0)
    x2 ~ Dists.LogNormal(0.0, 1.0)
    return x1^2/x2
end

m2 = @model m begin
    μ ~ m
    y ~ Soss.Normal(μ, 1.0)
end

mm = m2(m=m1())
xform(mm|(y=1.0,))

When I run this, I get an UndefVarError saying x2 is not defined, with the following stack trace:

ERROR: LoadError: UndefVarError: x2 not defined
Stacktrace:
 [1] macro expansion
   @ ~/.julia/packages/GeneralizedGenerated/PV9u7/src/closure_conv.jl:132 [inlined]
 [2] _basemeasure(M::Type{GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, _m::Model{NamedTuple{(), T} where T<:Tuple, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, _args::NamedTuple{(), Tuple{}}, _data::NamedTuple{(), Tuple{}}, _pars::NamedTuple{(), Tuple{}})
   @ Soss ~/.julia/packages/GeneralizedGenerated/PV9u7/src/closure_conv.jl:132
 [3] basemeasure(c::Soss.ConditionalModel{NamedTuple{(), T} where T<:Tuple, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}, x::NamedTuple{(), Tuple{}}) (repeats 2 times)
   @ Soss ~/.julia/packages/Soss/0y6uT/src/primitives/basemeasure.jl:7
 [4] testvalue(μ::Soss.ConditionalModel{NamedTuple{(), T} where T<:Tuple, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}})
   @ MeasureTheory ~/.julia/packages/MeasureTheory/swQOn/src/utils.jl:19
 [5] macro expansion
   @ ~/.julia/packages/GeneralizedGenerated/PV9u7/src/closure_conv.jl:132 [inlined]
 [6] _xform(M::Type{GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, _m::Model{NamedTuple{(:m,), T} where T<:Tuple, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, _args::NamedTuple{(:m,), Tuple{Model{NamedTuple{(), T} where T<:Tuple, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}}}, _data::NamedTuple{(:y,), Tuple{Float64}})
   @ Soss ~/.julia/packages/GeneralizedGenerated/PV9u7/src/closure_conv.jl:132
 [7] xform(m::Soss.ConditionalModel{NamedTuple{(:m,), T} where T<:Tuple, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}, NamedTuple{(:m,), Tuple{Model{NamedTuple{(), T} where T<:Tuple, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}}}, NamedTuple{(:y,), Tuple{Float64}}})
   @ Soss ~/.julia/packages/Soss/0y6uT/src/primitives/xform.jl:22
 [8] top-level scope
   @ ~/Research/Projects/SossTest/issue.jl:16
in expression starting at /home/ptiede/Research/Projects/SossTest/issue.jl:16

When I tried this on Soss's master branch two weeks ago, I didn't have this issue, so I am not sure what has changed since then. I am aware of https://github.com/cscherrer/Soss.jl/issues/245, but since this worked so recently, I am not sure it is the same problem.

cscherrer commented 3 years ago

Thanks for letting us know! Some notes so we can fix it:

First, I thought it might have something to do with mixing measures and distributions, but the error comes up even when everything is a Measure.

The problem seems to come up because

testvalue(μ::AbstractMeasure) = testvalue(basemeasure(μ))

and there's a bug in basemeasure:

julia> basemeasure(m1())
ERROR: UndefVarError: x2 not defined
Stacktrace:
 [1] macro expansion
   @ ~/.julia/packages/GeneralizedGenerated/PV9u7/src/closure_conv.jl:132 [inlined]
 [2] _basemeasure(M::Type{GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, _m::Model{NamedTuple{(), T} where T<:Tuple, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, _args::NamedTuple{(), Tuple{}}, _data::NamedTuple{(), Tuple{}}, _pars::NamedTuple{(), Tuple{}})
   @ Soss ~/.julia/packages/GeneralizedGenerated/PV9u7/src/closure_conv.jl:132
 [3] basemeasure(c::Soss.ConditionalModel{NamedTuple{(), T} where T<:Tuple, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}, x::NamedTuple{(), Tuple{}}) (repeats 2 times)
   @ Soss ~/git/Soss.jl/src/primitives/basemeasure.jl:7
 [4] top-level scope
   @ REPL[15]:1

julia> sourceBasemeasure(m1)
quote
    _bm = (;)
    _bm = merge(_bm, NamedTuple{(:x2,)}((basemeasure(Normal(0.0, 1.0)),)))
    x2 = Soss.predict(Normal(0.0, 1.0), x2)
    _bm = merge(_bm, NamedTuple{(:x1,)}((basemeasure(Normal(0.0, 1.0)),)))
    x1 = Soss.predict(Normal(0.0, 1.0), x1)
    nothing
    return ProductMeasure(_bm)
end
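As printed, the generated body reads x2 on the right-hand side of x2 = Soss.predict(..., x2) without any earlier assignment, and the stack trace shows that _pars was an empty NamedTuple, so there is no value to read from there either. The plain-Julia sketch below (no Soss involved; both function names are hypothetical) reproduces that failure mode and shows the general idea of taking each variable's value from a parameter NamedTuple before using it. It illustrates the ordering problem only; it is not the fix that actually went into Soss.

```julia
# Hypothetical stand-ins; none of these names come from Soss.

# Mirrors the shape of the generated body: x2 is assigned inside this
# function, so Julia treats it as a local variable, and reading it on the
# right-hand side before its first assignment throws UndefVarError.
function body_reads_before_assign()
    x2 = identity(x2)
    return x2
end

# Reading each variable from the supplied parameters first avoids using a
# local before it exists.
function body_reads_from_pars(pars::NamedTuple)
    x2 = pars.x2
    x1 = pars.x1
    return (x1 = x1, x2 = x2)
end

try
    body_reads_before_assign()
catch err
    println(typeof(err))                                # UndefVarError
end

println(body_reads_from_pars((x1 = 0.5, x2 = 2.0)))     # (x1 = 0.5, x2 = 2.0)
```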

I think there are three things we need to do:

cc @mschauer (since we had worked through the current basemeasure together) :)

cscherrer commented 3 years ago

@ptiede I got it working on my machine; can you confirm? You'll need the master branch of both Soss and MeasureTheory. Then, on my machine:

julia> using Soss

julia> using Soss: Dists

julia> m1 = @model begin
           x1 ~ Soss.Normal(0.0, 1.0)
           x2 ~ Dists.LogNormal(0.0, 1.0)
           return x1^2/x2
       end;

julia> m2 = @model m begin
           μ ~ m
           y ~ Soss.Normal(μ, 1.0)
       end;

julia> mm = m2(m=m1())
ConditionalModel given
    arguments    (:m,)
    observations ()
@model m begin
        μ ~ m
        y ~ Soss.Normal(μ, 1.0)
    end

julia> xform(mm|(y=1.0,))
TransformVariables.TransformTuple{NamedTuple{(:μ,), Tuple{TransformVariables.TransformTuple{NamedTuple{(:x2, :x1), Tuple{TransformVariables.ShiftedExp{true, Float64}, TransformVariables.Identity}}}}}}((μ = TransformVariables.TransformTuple{NamedTuple{(:x2, :x1), Tuple{TransformVariables.ShiftedExp{true, Float64}, TransformVariables.Identity}}}((x2 = asℝ₊, x1 = asℝ), 2),), 2)
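The transform returned above maps a 2-dimensional unconstrained vector onto the inner model's parameters, nested under μ: asℝ₊ for x2 (positive, as the LogNormal requires) and asℝ for x1. Assuming asℝ₊ is an exponential map with zero shift (which is what ShiftedExp{true, Float64} suggests) and that coordinates are consumed in the listed order (x2, then x1), a hand-written analogue looks roughly like this. It is a sketch for intuition only: the function name is hypothetical, it does not call TransformVariables, and it omits the log-Jacobian term a sampler would also need.

```julia
# Hypothetical hand-written analogue of the xform result above.
# Assumes asℝ₊ maps u to exp(u) (shift 0) and asℝ is the identity.
function to_model_params(u::AbstractVector{<:Real})
    length(u) == 2 ||
        throw(DimensionMismatch("expected 2 unconstrained coordinates, got $(length(u))"))
    x2 = exp(u[1])   # strictly positive, matching the LogNormal support
    x1 = u[2]        # unconstrained, matching the Normal support
    return (μ = (x2 = x2, x1 = x1),)   # nested under μ, like the outer TransformTuple
end

to_model_params([0.0, -1.5])   # returns (μ = (x2 = 1.0, x1 = -1.5),)
```

Any real vector of length 2 lands on a valid point, which is why a sampler can work in the unconstrained space.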
ptiede commented 3 years ago

Thank you so much! It is working, but I found a new issue in basemeasure and/or testvalue in MeasureTheory. I am getting a StackOverflowError when I change the test to:

m1 = @model begin
    x1 ~ Soss.Normal(0.0, 1.0)
    x2 ~ Dists.MvNormal(fill(x1,2), ones(2))
    return x2
end

m2 = @model m begin
    μ ~ m
    y ~ For(μ) do x 
        Soss.Normal(x, 1.0)
    end
end

mm = m2(m=m1())
xform(mm|(;y=1.0,))

The stack trace is:

ERROR: LoadError: StackOverflowError:
Stacktrace:
 [1] testvalue(μ::MeasureTheory.Lebesgue{Vector{Float64}}) (repeats 79984 times)
   @ MeasureTheory ~/.julia/packages/MeasureTheory/PfH9P/src/utils.jl:18
in expression starting at /home/ptiede/Research/Projects/SossTest/issue.jl:17
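The single repeated frame is consistent with the fallback quoted earlier in the thread, testvalue(μ::AbstractMeasure) = testvalue(basemeasure(μ)): it only terminates if repeatedly taking base measures reaches a measure with its own testvalue method. A primitive measure such as Lebesgue is presumably its own base measure, so if no specific method applies to Lebesgue{Vector{Float64}}, the fallback calls itself with the same argument until the stack runs out. The toy types below reuse the names basemeasure and testvalue but are hypothetical stand-ins, not MeasureTheory's definitions or its actual fix; they only show the loop and the kind of terminating method that breaks it.

```julia
# Hypothetical toy types; not MeasureTheory's real definitions.
abstract type ToyMeasure end
struct ToyLebesgue <: ToyMeasure end       # a "primitive" measure
struct ToyNormal <: ToyMeasure end

basemeasure(::ToyNormal) = ToyLebesgue()
basemeasure(μ::ToyLebesgue) = μ            # primitive: its own base measure

# Generic fallback with the same shape as the one quoted in the thread.
testvalue(μ::ToyMeasure) = testvalue(basemeasure(μ))

# With only the fallback, testvalue(ToyLebesgue()) recurses on the same
# argument forever; Julia does not eliminate tail calls, so the stack overflows.
overflowed = try
    testvalue(ToyLebesgue())
    false
catch err
    err isa StackOverflowError
end
println(overflowed)                        # true

# A method for the primitive measure gives the recursion a base case.
testvalue(::ToyLebesgue) = 0.0
println(testvalue(ToyNormal()))            # 0.0
```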
cscherrer commented 3 years ago

This should be working now; we'll add these as tests. This is really helpful!

cscherrer commented 3 years ago

Done! :)