cscherrer / Soss.jl

Probabilistic programming via source rewriting
https://cscherrer.github.io/Soss.jl/stable/
MIT License

fix-JET-infer #321

Closed thautwarm closed 2 years ago

thautwarm commented 2 years ago

Fixes JuliaStaging/GeneralizedGenerated.jl#69 by making the following tests pass:

using Soss, JET

m1 = @model N begin
    p ~ Uniform()
    x ~ For(N) do j
        Bernoulli(p / j)
    end
end

@test_opt rand(m1(10))

m3 = @model N begin
    p ~ Uniform()
    f(ctx) = Base.Fix1(ctx) do ctx, j
        Bernoulli(ctx.p / j)
    end
    x ~ For(f((p=p,)), N)
end

@test_opt rand(m3(10))
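A note on the m3 pattern, since it is the less common of the two: Base.Fix1(f, x) builds a callable that behaves like y -> f(x, y), and do-block syntax passes the block as the first argument, so Base.Fix1(ctx) do ctx, j ... end means Base.Fix1((ctx, j) -> ..., ctx). The context therefore travels as an explicit field of the callable instead of being captured by a closure. The sketch below is plain Base Julia; as an assumption made only so it runs without extra packages, it returns the number p / j where m3 returns Bernoulli(p / j).

```julia
# Same shape as f(ctx) in m3, but returning p / j instead of Bernoulli(p / j)
# so that it runs without Soss or Distributions.
f(ctx) = Base.Fix1(ctx) do ctx, j
    ctx.p / j
end

g = f((p = 0.5,))   # a Base.Fix1 wrapping the anonymous function and the context
g.x                 # → (p = 0.5,); the stored context is an ordinary field
g(2)                # → 0.25, i.e. ctx.p / j with ctx = (p = 0.5,) and j = 2
```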
thautwarm commented 2 years ago

I'm not really sure why creating a gg function using mkfun and calling it later does not work.

cscherrer commented 2 years ago

Thanks! I'm getting an error, though. Does this work for you?

julia> rand(m1(10))
:(:(ERROR: task switch not allowed from inside staged nor pure functions
Stacktrace:
  [1] try_yieldto(undo::typeof(Base.ensure_rescheduled))
    @ Base ./task.jl:767
  [2] wait()
    @ Base ./task.jl:837
  [3] uv_write(s::Base.TTY, p::Ptr{UInt8}, n::UInt64)
    @ Base ./stream.jl:992
  [4] unsafe_write(s::Base.TTY, p::Ptr{UInt8}, n::UInt64)
    @ Base ./stream.jl:1064
  [5] unsafe_write
    @ ./io.jl:362 [inlined]
  [6] write
    @ ./strings/io.jl:244 [inlined]
  [7] print
    @ ./strings/io.jl:246 [inlined]
  [8] show_unquoted_quote_expr(io::IOContext{Base.TTY}, value::Any, indent::Int64, prec::Int64, quote_level::Int64)
    @ Base ./show.jl:1685
  [9] show(io::Base.TTY, ex::Expr)
    @ Base ./show.jl:1304
 [10] show(x::Expr)
    @ Base ./show.jl:393
 [11] #s54#41
    @ ~/git/Soss.jl/src/primitives/interpret.jl:96 [inlined]
 [12] var"#s54#41"(MC::Any, T::Any, ::Any, _mc::Any, #unused#::Any, _cfg::Any, _ctx::Any)
    @ Soss ./none:0
 [13] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:580
 [14] #rand#46
    @ ~/git/Soss.jl/src/primitives/rand.jl:35 [inlined]
 [15] rand
    @ ~/git/Soss.jl/src/primitives/rand.jl:34 [inlined]
 [16] #rand#44
    @ ~/git/Soss.jl/src/primitives/rand.jl:19 [inlined]
 [17] rand(m::Soss.ModelClosure{ASTModel{NamedTuple{(:N,)}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{102}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(:N,), Tuple{Int64}}})
    @ Soss ~/git/Soss.jl/src/primitives/rand.jl:19
 [18] top-level scope
    @ REPL[4]:1
cscherrer commented 2 years ago

I think this error comes from Base.show(xs), but even without that call, things still aren't quite working:

julia> rand(m1(10))
ERROR: UndefVarError: _mc not defined
Stacktrace:
 [1] getproperty
   @ ./Base.jl:35 [inlined]
 [2] macro expansion
   @ ~/git/Soss.jl/src/primitives/interpret.jl:75 [inlined]
 [3] mkfun_call(_mc::Soss.ModelClosure{ASTModel{NamedTuple{(:N,)}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{102}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(:N,), Tuple{Int64}}}, #unused#::typeof(Soss.tilde_rand), _cfg::NamedTuple{(:rng,), Tuple{Random._GLOBAL_RNG}}, _ctx::NamedTuple{(), Tuple{}})
   @ Soss ~/git/Soss.jl/src/primitives/interpret.jl:75
 [4] #rand#46
   @ ~/git/Soss.jl/src/primitives/rand.jl:35 [inlined]
 [5] rand
   @ ~/git/Soss.jl/src/primitives/rand.jl:34 [inlined]
 [6] #rand#44
   @ ~/git/Soss.jl/src/primitives/rand.jl:19 [inlined]
 [7] rand(m::Soss.ModelClosure{ASTModel{NamedTuple{(:N,)}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{102}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(:N,), Tuple{Int64}}})
   @ Soss ~/git/Soss.jl/src/primitives/rand.jl:19
 [8] top-level scope
   @ REPL[4]:1

But I do think we're headed in the right direction.

cscherrer commented 2 years ago

I found this to work for the test case:

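@gg function mkfun_call(_mc::MC, ::T, _cfg, _ctx) where {MC, T}
    # Generator phase: only the types MC and T are available here
    _m = type2model(MC)
    M = getmodule(_m)

    _args = argvalstype(MC)
    _obs = obstype(MC)

    tilde = T.instance   # singleton instance of the tilde function's type
    body = _m.body |> loadvals(_args, _obs)
    body = _interpret(M, body, tilde, _args, _obs)

    # Generated code: here _mc, _cfg and _ctx are the runtime argument values
    q = MacroTools.flatten(quote
        local _retn
        _args = Soss.argvals(_mc)
        _obs = Soss.observations(_mc)
        _cfg = merge(_cfg, (args=_args, obs=_obs))
        $body
        _retn
    end)

    # globals in q (+, *, ...) are resolved in the model's module M
    @under_global M q
end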
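Why the quoted block can refer to _mc: in a generated function, the generator body only sees the argument types, while the expression it returns becomes the method body, where the argument names are bound to runtime values. The sketch below shows this with Base Julia's @generated rather than GeneralizedGenerated's @gg; that @gg follows the same rule here is an assumption, and the function name fieldsum is hypothetical.

```julia
# Hypothetical example using Base's @generated (not GeneralizedGenerated's @gg).
@generated function fieldsum(nt::NT) where {NT<:NamedTuple}
    # Generator phase: `nt` is bound to the *type* here, so only type-level
    # information such as the field names can be inspected.
    names = fieldnames(NT)
    terms = [:(nt.$n) for n in names]
    # Returned expression: inside it, `nt` refers to the runtime argument value.
    return :(+(0, $(terms...)))
end

fieldsum((a = 1, b = 2, c = 3))   # → 6
```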

Then I get

julia> rand(m1(10))
(p = 0.880486, x = Bool[1, 1, 0, 1, 0, 0, 0, 0, 0, 0])

julia> @test_opt rand(m1(10))
Test Passed
  Expression: #= REPL[9]:1 =# JET.@test_call analyzer = OptAnalyzer rand(m1(10))

But I'm not sure how robust this is when it's used from a different module, etc. A few things here still don't make sense to me...

First, I think I see what mk_expr is trying to do, but I can't yet get it to work.

I see this idiom come up a lot:

M = ...
@q let M
    ... # body of the function
end

I don't see how this can work. M is defined before the quote, so wouldn't it have to be interpolated into the expression?

And then for testing... Is there a minimal "calling things from other modules" setup you'd use to test for robustness of this sort of thing? Does precompilation complicate that?
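On the interpolation question: inside a quote, a name that is not $-interpolated stays a bare Symbol and is only looked up when the expression is evaluated, in whatever scope or module it ends up in, whereas $M splices in the value M has at the moment the quote is built. A small Base-Julia illustration (the variable names are only for the example and are unrelated to the Soss source):

```julia
M = 10
q_sym = :(M + 1)    # M stays the Symbol :M inside the expression
q_val = :($M + 1)   # the current value of M (10) is spliced in now

M = 100             # rebind M after both quotes were built
eval(q_sym)         # → 101: :M is looked up when the expression is evaluated
eval(q_val)         # → 11: the spliced-in value is fixed
```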

thautwarm commented 2 years ago

Just remove the call to Base.show in mkfun_call.
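This matches the earlier stacktrace: the generator runs during compilation, and printing to a TTY went through uv_write and wait(), which tried to switch tasks, something that is not allowed inside a staged function. A minimal sketch of a safe pattern, using Base's @generated and hypothetical names: build the code in an ordinary function, have the generator only call it, and do any printing of the generated code outside the generator.

```julia
# Hypothetical names. An ordinary (non-generated) function builds the code,
# so it can be called and printed freely at the REPL while debugging.
function tuplesum_expr(::Type{T}) where {T<:Tuple}
    terms = [:(t[$i]) for i in 1:fieldcount(T)]
    return :(+(0, $(terms...)))
end

# The generator only calls the builder (defined above it) and returns the
# result; it does no I/O such as show or println.
@generated function tuplesum(t::T) where {T<:Tuple}
    return tuplesum_expr(T)
end

# Debugging: inspect the generated code outside of the generator.
println(tuplesum_expr(Tuple{Int,Int,Int}))
tuplesum((1, 2, 3))   # → 6
```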

cscherrer commented 2 years ago

Does that work for you? If I do that, I get an error that _mc is not defined.

thautwarm commented 2 years ago

Oh, I'm not sure why, but it now raises "_mc not defined". It's probably caused by some cleanup I did before opening the PR. Fixing it.

thautwarm commented 2 years ago

Updated. Could you please give it a try?

thautwarm commented 2 years ago

It works locally now.

julia> rand(m3(10))
(p = 0.189824138731166, x = Bool[0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

julia> rand(m1(10))
(p = 0.421553717757589, x = Bool[0, 1, 0, 0, 0, 0, 0, 0, 0, 0])

julia> @test_opt rand(m1(10))
Test Passed
  Expression: #= REPL[6]:1 =# JET.@test_call analyzer = JET.OptAnalyzer rand(m1(10))

> M is defined before the quote, so wouldn't it have to be interpolated into the expression?

I don't think M is explicitly used in the code generated by Soss; it is used to resolve global variables like +, *, etc. We just pass M as the first argument of mk_function/mk_expr, which specifies the callee module.
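To illustrate what resolving globals like +, * in M means with plain Julia: a free name in an expression refers to whatever that name is bound to in the module where the expression is evaluated. In this sketch Core.eval stands in for passing M to mk_function/mk_expr (an analogy, not GeneralizedGenerated's implementation), and the module Callee with its bindings scale and double is hypothetical.

```julia
# Hypothetical module standing in for the model's module M.
module Callee
    scale = 10            # a global that exists only in Callee
    double(x) = 2 * x
end

ex = :(double(scale) + 1)   # free names: double, scale and +

# Evaluated in Callee: double and scale resolve there, and + resolves
# through Callee's implicit `using Base`.
Core.eval(Callee, ex)       # → 21

# The same expression evaluated in the current module fails, because
# double and scale are not defined here.
failed_here = try
    Core.eval(@__MODULE__, ex)
    false
catch err
    err isa UndefVarError
end
```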

If any generated code explicitly references M, please let me know and I'll make a minor fix to support it.