cscherrer / Soss.jl

Probabilistic programming via source rewriting
https://cscherrer.github.io/Soss.jl/stable/
MIT License

For(...) do i ... broken using Soss with MeasureTheory.jl@v0.11.3 or later #305

Closed. Tuebel closed this issue 2 years ago.

Tuebel commented 2 years ago

I'm not sure whether this issue belongs in this repo, in MeasureTheory, or in MeasureBase.

This issue only occurs with MeasureTheory.jl@v0.11.3 or later, which bumped MeasureBase.jl to v0.3; MeasureTheory.jl@v0.11.2 works fine. I think that is also around the time ProductMeasure was ported to MeasureBase.

When I use the For(...) do i ... syntax to create a ProductMeasure in a Soss model, I get an error:

using Soss, MeasureTheory
model = @model begin
    a ~ For(3) do x Normal(μ=x) end
end
rand(model())  # fails with a MethodError like the one below
MethodError: no method matching productmeasure(::ggfunc-function, ::UnitRange{Int64})
Closest candidates are:
  productmeasure(::Any, ::Any, ::Any) at /home/rd/.julia/packages/MeasureBase/XtMtF/src/combinators/smart-constructors.jl:69
  productmeasure(::Kernel, ::Any) at /home/rd/.julia/packages/MeasureBase/XtMtF/src/combinators/smart-constructors.jl:75
  productmeasure(::Function, ::Any) at /home/rd/.julia/packages/MeasureBase/XtMtF/src/combinators/smart-constructors.jl:77
  ...

Stacktrace:
  [1] For(f::ggfunc-function, n::Int64)
    @ MeasureBase ~/.julia/packages/MeasureBase/XtMtF/src/combinators/for.jl:82
  [2] macro expansion
    @ ~/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:116 [inlined]
  [3] (::ggfunc-function)(::Matrix{Float64}, ::Random._GLOBAL_RNG; pkwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ GeneralizedGenerated.NGG ~/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:83
  [4] (::ggfunc-function)(::Matrix{Float64}, ::Vararg{Any, N} where N)
    @ GeneralizedGenerated.NGG ~/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:83
  [5] (::GeneralizedGenerated.Closure{function = (x, _rng;) -> begin
    begin
        θ = (Main).mean(x)
        λ_2 = (Main).rand(_rng, (Main).Exponential(θ))
        λ_1 = (Main).rand(_rng, (Main).Exponential(θ))
        a = (Main).rand(_rng, (Main).For(function = (x;) -> begin
    (Main).Normal(μ = x)
end, 3))
        N = (Main).length(x)
        (θ = θ, N = N, λ_2 = λ_2, λ_1 = λ_1, a = a)
    end
end, Tuple{Matrix{Float64}}})(args::Random._GLOBAL_RNG; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ GeneralizedGenerated ~/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6
  [6] (::GeneralizedGenerated.Closure{function = (x, _rng;) -> begin
    begin
        θ = (Main).mean(x)
        λ_2 = (Main).rand(_rng, (Main).Exponential(θ))
        λ_1 = (Main).rand(_rng, (Main).Exponential(θ))
        a = (Main).rand(_rng, (Main).For(function = (x;) -> begin
    (Main).Normal(μ = x)
end, 3))
        N = (Main).length(x)
        (θ = θ, N = N, λ_2 = λ_2, λ_1 = λ_1, a = a)
    end
end, Tuple{Matrix{Float64}}})(args::Random._GLOBAL_RNG)
    @ GeneralizedGenerated ~/.julia/packages/GeneralizedGenerated/PV9u7/src/closure.jl:6
  [7] rand
    @ ~/.julia/packages/Soss/89FWt/src/primitives/rand.jl:20 [inlined]
  [8] rand(m::Soss.ConditionalModel{NamedTuple{(:x,), T} where T<:Tuple, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}, NamedTuple{(:x,), Tuple{Matrix{Float64}}}, NamedTuple{(), Tuple{}}})
    @ Soss ~/.julia/packages/Soss/89FWt/src/primitives/rand.jl:24
  [9] top-level scope
    @ In[11]:2
 [10] eval
    @ ./boot.jl:360 [inlined]
 [11] include_string(mapexpr:

However, calling rand directly on the ProductMeasure succeeds, so no tests in MeasureTheory.jl would fail:

a = For(3) do x Normal(μ=x) end
rand(a)
(a = [0.7378719090151773, 1.2859026155951314, 2.451153694681324],)

Thanks for this library by the way :)

cscherrer commented 2 years ago

Thanks for letting me know about this! This is obviously a high priority to fix. I think (hope?) it should be a quick fix; then we should add tests to make sure we catch this going forward. I'll spend some time on it today.

cscherrer commented 2 years ago

Here's an MWE, with the Soss model taken out of the picture:

julia> using Soss, MeasureTheory, GeneralizedGenerated

julia> g = GG.mk_function(:(x -> Normal(x,1)))
function = (x;) -> begin
    begin
        (GeneralizedGenerated).Normal(x, 1)
    end
end

julia> For(g, 3)
ERROR: MethodError: no method matching productmeasure(::ggfunc-function, ::UnitRange{Int64})
Closest candidates are:
  productmeasure(::Any, ::Any, ::Any) at /home/chad/git/MeasureBase/src/combinators/smart-constructors.jl:69
  productmeasure(::Kernel, ::Any) at /home/chad/git/MeasureBase/src/combinators/smart-constructors.jl:75
  productmeasure(::Returns, ::Any, ::Any) at /home/chad/git/MeasureBase/src/combinators/smart-constructors.jl:82
  ...
Stacktrace:
 [1] For(f::ggfunc-function, n::Int64)
   @ MeasureBase ~/git/MeasureBase/src/combinators/for.jl:82
 [2] top-level scope
   @ REPL[53]:1

mk_function calls GeneralizedGenerated.closure_conv, which produces a callable whose type is not a subtype of Function:

julia> typeof(g)
ggfunc-function

julia> typeof(g) |> supertype
Any

So the issue is that MeasureBase.productmeasure has no generic two-argument method. It wants a <:Function (or a Kernel), and a ggfunc-function doesn't meet this criterion.
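The dispatch behavior described here can be reproduced in plain Julia, without Soss, MeasureBase, or GeneralizedGenerated. This is a minimal sketch under stated assumptions: `CallableShift` is a hypothetical callable struct standing in for a `ggfunc-function` (it only models the fact that its supertype is `Any`), and `toy_productmeasure` is a hypothetical function that, like the `productmeasure(::Function, ::Any)` candidate in the error, only accepts a `Function` as its first argument.

```julia
# Hypothetical stand-in for a GeneralizedGenerated closure:
# it is callable, but its supertype is `Any`, not `Function`.
struct CallableShift
    offset::Int
end
(c::CallableShift)(x) = x + c.offset

# Hypothetical method restricted to `Function`,
# mirroring the `productmeasure(::Function, ::Any)` candidate.
toy_productmeasure(f::Function, pars) = map(f, pars)

g = CallableShift(1)
anon = x -> x + 1

println(supertype(CallableShift))       # Any
println(g isa Function)                 # false
println(anon isa Function)              # true
println(toy_productmeasure(anon, 1:3))  # [2, 3, 4]

# The callable struct matches no method, as in the issue
err = try
    toy_productmeasure(g, 1:3)
    nothing
catch e
    e
end
println(err isa MethodError)            # true
```

Being callable is not enough for dispatch: a `::Function` annotation only matches values whose type actually subtypes `Function`.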

The simplest fix is to add a generic method:

productmeasure(f, pars) = productmeasure(f, identity, pars)

For example, this works:

julia> MeasureBase.productmeasure(f, pars) = MeasureBase.productmeasure(f, identity, pars)

julia> m = @model begin 
           x ~ For(3) do j Normal(j,1) end
       end;

julia> rand(m())
(x = [1.9296440232404293, 3.1574221730525336, 4.1352330633979255],)

But it turns out this just kicks the can down the road:

julia> logpdf(m(), rand(m()))
ERROR: MethodError: no method matching basekernel(::ggfunc-function)
Closest candidates are:
  basekernel(::Kernel) at /home/chad/git/MeasureBase/src/kernel.jl:73
  basekernel(::Returns) at /home/chad/git/MeasureBase/src/kernel.jl:74
  basekernel(::Function) at /home/chad/git/MeasureBase/src/kernel.jl:71

To be continued...
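A plain-Julia sketch of why a single fallback only moves the error: any other method that is also restricted to `::Function` (here a hypothetical `toy_basekernel`, mirroring the `basekernel(::Function)` candidate in the error) still rejects the callable. It also shows one possible alternative, wrapping the callable in an anonymous function, which is always a subtype of `Function`. All names are hypothetical, and this is not necessarily the approach taken in the fix released later.

```julia
# Hypothetical callable whose supertype is `Any` (stand-in for a ggfunc-function)
struct CallableShift
    offset::Int
end
(c::CallableShift)(x) = x + c.offset

# Two hypothetical entry points, both restricted to `::Function`
toy_productmeasure(f::Function, pars) = map(f, pars)
toy_basekernel(f::Function) = f

# A generic fallback for the first entry point only,
# similar in spirit to the patch proposed earlier in the thread
toy_productmeasure(f, pars) = map(f, pars)

g = CallableShift(1)
println(toy_productmeasure(g, 1:3))   # [2, 3, 4]: the fallback catches it

# ...but the next `::Function`-only method still fails
second = try
    toy_basekernel(g)
catch e
    e
end
println(second isa MethodError)       # true

# Alternative workaround: an anonymous function wrapping the callable
# is a subtype of `Function`, so it satisfies every such signature.
wrapped = x -> g(x)
println(wrapped isa Function)         # true
println(toy_basekernel(wrapped)(10))  # 11
```

The trade-off: the fallback approach has to be repeated for every `::Function`-restricted method, while the wrapping approach has to be applied at every call site that passes such a callable.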

cscherrer commented 2 years ago

Hi @Tuebel, thanks again for this issue. There's a new release with tests that include:

    @testset "https://github.com/cscherrer/Soss.jl/issues/305" begin
        m = @model begin 
            x ~ For(3) do j Normal(μ=j) end
        end;

        @test logpdf(m(), rand(m())) isa Float64
    end

Tuebel commented 2 years ago

Thanks for the quick fix!