cscherrer / Soss.jl

Probabilistic programming via source rewriting
https://cscherrer.github.io/Soss.jl/stable/
MIT License

Scoping for nested models #245

Open cscherrer opened 3 years ago

cscherrer commented 3 years ago

This works just fine:

julia> m1 = @model a, b begin
                  p ~ Beta(a, b)
                  x ~ Normal(p, 1.0) |> iid(3)
              end;

julia> m2 = @model begin
                  a ~ Beta(0.5, 0.5)
                  b ~ Beta(1, 0.5)
                  m ~ m1(a = a, b = b)
              end;

julia> t = xform(m2() | (; m = (; x = rand(3))))
TransformVariables.TransformTuple{NamedTuple{(:b, :a, :m), Tuple{TransformVariables.ScaledShiftedLogistic{Float64}, TransformVariables.ScaledShiftedLogistic{Float64}, TransformVariables.TransformTuple{NamedTuple{(:p,), Tuple{TransformVariables.ScaledShiftedLogistic{Float64}}}}}}}((b = as𝕀, a = as𝕀, m = TransformVariables.TransformTuple{NamedTuple{(:p,), Tuple{TransformVariables.ScaledShiftedLogistic{Float64}}}}((p = as𝕀,), 1)), 3)

julia> logdensity(m2() | (; m = (; x = rand(3))), t(randn(3)))
-1.3902285607281852

But in tests I had to do

    @testset "Nested models" begin
        m1 = @model a, b begin
            p ~ Beta(a, b)
            x ~ Normal(p, 1.0) |> iid(3)
        end

        m2 = @model begin
            a ~ Beta(0.5, 0.5)
            b ~ Beta(1, 0.5)
            m ~ m1(a = a, b = b)
        end

        @test_broken let t = xform(m2() | (; m = (; x = rand(3))))
            logdensity(m2() | (; m = (; x = rand(3))), t(randn(3))) isa Float64
        end
    end

Without this, the xform call fails with

Nested models: Error During Test at REPL[24]:1
  Got exception outside of a @test
  UndefVarError: α not defined
  Stacktrace:
    [1] getproperty
      @ ./Base.jl:26 [inlined]
    [2] macro expansion
      @ ~/.julia/packages/GeneralizedGenerated/hIoV7/src/closure_conv.jl:121 [inlined]
    [3] _xform(#unused#::Type{TypeEncoding(Main)}, _m::Model{NamedTuple{(:n, :α, :β), T} where T<:Tuple, TypeEncoding(begin
      p ~ Beta(α, β)
      x ~ Binomial(n, p)
  end), TypeEncoding(Main)}, _args::NamedTuple{(:a, :b), Tuple{Float64, Float64}}, _data::NamedTuple{(:x,), Tuple{Vector{Float64}}})
      @ Soss ~/.julia/packages/GeneralizedGenerated/hIoV7/src/closure_conv.jl:121
    [4] xform(m::Soss.ConditionalModel{NamedTuple{(:n, :α, :β), T} where T<:Tuple, TypeEncoding(begin
      p ~ Beta(α, β)
      x ~ Binomial(n, p)
  end), TypeEncoding(Main), NamedTuple{(:a, :b), Tuple{Float64, Float64}}, NamedTuple{(:x,), Tuple{Vector{Float64}}}})
      @ Soss ~/git/Soss.jl/src/primitives/xform.jl:23
    [5] xform(m::Soss.ConditionalModel{NamedTuple{(:n, :α, :β), T} where T<:Tuple, TypeEncoding(begin
      p ~ Beta(α, β)
      x ~ Binomial(n, p)
  end), TypeEncoding(Main), NamedTuple{(:a, :b), Tuple{Float64, Float64}}, NamedTuple{(), Tuple{}}}, _data::NamedTuple{(:x,), Tuple{Vector{Float64}}})
      @ Soss ~/git/Soss.jl/src/primitives/xform.jl:20
    [6] macro expansion
      @ ~/.julia/packages/GeneralizedGenerated/hIoV7/src/closure_conv.jl:121 [inlined]
    [7] _xform(#unused#::Type{TypeEncoding(Main)}, _m::Model{NamedTuple{(), T} where T<:Tuple, TypeEncoding(begin
      b ~ Beta(1, 0.5)
      a ~ Beta(0.5, 0.5)
      m ~ m1(a = a, b = b)
  end), TypeEncoding(Main)}, _args::NamedTuple{(), Tuple{}}, _data::NamedTuple{(:m,), Tuple{NamedTuple{(:x,), Tuple{Vector{Float64}}}}})
      @ Soss ~/.julia/packages/GeneralizedGenerated/hIoV7/src/closure_conv.jl:121
    [8] xform(m::Soss.ConditionalModel{NamedTuple{(), T} where T<:Tuple, TypeEncoding(begin
      b ~ Beta(1, 0.5)
      a ~ Beta(0.5, 0.5)
      m ~ m1(a = a, b = b)
  end), TypeEncoding(Main), NamedTuple{(), Tuple{}}, NamedTuple{(:m,), Tuple{NamedTuple{(:x,), Tuple{Vector{Float64}}}}}})
      @ Soss ~/git/Soss.jl/src/primitives/xform.jl:23

It seems to be mistakenly finding this model from test/transforms.jl:

m = @model (n,α,β) begin
    p ~ Beta(α, β)
    x ~ Binomial(n, p)
    z ~ Binomial(n, α/(α+β))
end

@thautwarm I'm not sure what's going on here, looks like a scoping issue. Any ideas?

thautwarm commented 3 years ago

This is due to how the Test package processes the code.

@testset creates a new local scope, so m1 = ... creates a local binding, but m1 inside model m2 tries to read a global variable from the current module.
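
For illustration, here's a minimal sketch of that effect (a hypothetical session, not taken from the issue): the body of m2 stores only the name m1, which is later looked up among the module's globals, so a local binding that shadows a global one is ignored.

m1 = @model a begin            # global binding in Main
    x ~ Normal(a, 1.0)
end

let
    m1 = @model a begin        # local binding shadowing the global one
        x ~ Uniform()
    end
    m2 = @model begin
        m ~ m1(a = 0.0)        # still resolves to the *global* m1, i.e. the Normal model
    end
    rand(m2())
end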

cscherrer commented 3 years ago

Do you think this is a concern, or just a quirk of the test environment? If it's just test environment weirdness, we could maybe just change some names to fix it.

thautwarm commented 3 years ago

Do you think this is a concern, or just a quirk of the test environment?

It depends.

There is no bug, but we should consider supporting local models.

Let me say something related first.

This issue actually stems from creating a @model in a local scope.

 m2 = @model begin
    a ~ Beta(0.5, 0.5)
    b ~ Beta(1, 0.5)
    m ~ m1(a = a, b = b)
 end

The m1 above references a global binding, but I wonder if you'd like it to reference a local one?

Automatically resolving m1 from the caller's scope is difficult, because we only store the AST of the model body.

I'd propose an elegant tradeoff: we can use $m1 to reference the local m1.

 m2 = @model begin
    a ~ Beta(0.5, 0.5)
    b ~ Beta(1, 0.5)
    m ~ $m1(a = a, b = b)
 end
cscherrer commented 3 years ago

Thanks @thautwarm. Just to be sure I understand: does this issue affect all references to local values, or just references to other models? And is the problem that it can't find the local m1 at all, or just that it doesn't know which one should take precedence?

Ideally we'd have the same scoping rules as Julia. When that's not possible, could we show a one-time warning? Maybe something like

> WARNING: Calling global m1. The name exists in both global and local scope; use $m1 to refer to the local value.

Or maybe knowing when to warn in this way is just as difficult as the original problem?

thautwarm commented 3 years ago

The problem is, when using @model locally, we cannot know if m1 is a local variable or a global variable.


function f(m1)
    @model ... begin
         ...
         a ~ m1(...)
    end
end

The macro call @model ... begin ... end cannot see anything outside itself, so this is impossible to analyze.

The only way to reference local variables is to use esc(:m1) in the expression returned by @model.

But right now we keep the ASTs of the statements as data, which means we cannot use esc.

thautwarm commented 3 years ago

A possible solution is to make @model generate an esc-ed expression that constructs a Model; that is, the return value of the @model macro would be an expression in which symbols are escaped. But in that case, a and b in @model a, b begin ... end would also need to be escaped.

thautwarm commented 3 years ago

This is very similar to the following case:

When you want to get the AST :(m + a), you cannot easily resolve m from the local scope:

function f(m)
    :(m + a)
end
function resolve(ex)
    ...
end

# every symbol in Expr.args has to be interpolated,
# because we cannot distinguish them from each other
@assert resolve(:(m + a)) == :( :($+($m, $a)) )

macro resolve(ex)
   esc(resolve(ex))
end

function f(m)
   @resolve m + a  # m from local, a from global
end

Things become more difficult when it comes to parametric models. It's doable, but difficult.

In this sense, it's more reasonable to write

function f(m)
    :($m + a)   # distinguish the local and global manually
end
cscherrer commented 3 years ago

Thinking some more about this, @thautwarm, maybe this example makes the current behavior clearer:

julia> using Soss

julia> function f(x)
           m = @model begin
               y ~ Normal(x,1)
           end
       end;

julia> f(2.0)
@model begin
        y ~ Normal(x, 1)
    end

julia> rand(f(2.0))
ERROR: UndefVarError: x not defined
Stacktrace:
 [1] getproperty
   @ ./Base.jl:26 [inlined]
 [2] macro expansion
   @ ~/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:116 [inlined]
 [3] (::ggfunc-function)(pargs::Random._GLOBAL_RNG; pkwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
   @ GeneralizedGenerated.NGG ~/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:83
 [4] RuntimeFn
   @ ~/.julia/packages/GeneralizedGenerated/PV9u7/src/ngg/runtime_fns.jl:83 [inlined]
 [5] rand
   @ ~/git/Soss.jl/src/primitives/rand.jl:28 [inlined]
 [6] rand(m::Model{NamedTuple{(), T} where T<:Tuple, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{23}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}})
   @ Soss ~/git/Soss.jl/src/primitives/rand.jl:31
 [7] top-level scope
   @ REPL[7]:1

julia> x = 200.0
200.0

julia> rand(f(2.0))
(y = 199.89822240734952,)

julia> x = 100.0
100.0

julia> rand(f(2.0))
(y = 100.36657113705026,)

At least for me, this makes things clearer :)

cscherrer commented 3 years ago

Currently a model is entirely static: all information is contained in the AST. Would it be possible/sensible to have a dynamic component as well? Maybe values in local scope could be stored in a closure or named tuple; then a "new model" could be one of these paired with an AST.

I have no idea if this actually makes sense, or whether there might be a performance penalty.
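
To make that concrete, here's a rough sketch (hypothetical, not existing Soss API; the ClosedModel name and the captured field are made up) of pairing an AST-based model with values frozen from the defining scope:

struct ClosedModel{M,C<:NamedTuple}
    model::M       # the usual AST-based model
    captured::C    # values frozen from the defining scope, e.g. (m1 = m1,)
end

function build()
    m1 = @model a, b begin
        p ~ Beta(a, b)
        x ~ Normal(p, 1.0) |> iid(3)
    end
    m2 = @model begin
        a ~ Beta(0.5, 0.5)
        b ~ Beta(1, 0.5)
        m ~ m1(a = a, b = b)
    end
    # the names and types of the captured values are known statically, so codegen
    # could turn `m1` into a lookup in `captured` instead of a global reference
    ClosedModel(m2, (m1 = m1,))
end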

cscherrer commented 3 years ago

Maybe this can help? https://github.com/MasonProtter/StaticModules.jl

thautwarm commented 3 years ago

Since we are manipulating ASTs, using $ would be the most direct approach.

Otherwise, we need a macro to transform :(Normal(x, 1)) into :($Normal($x, 1)).

In StaticModules.jl, Mason uses JuliaVariables.jl to resolve the scope of names, so the aforementioned transformation is easy.

However, JuliaVariables.jl normalizes the ASTs, for example by eliminating do expressions, which I think might not be good for Soss.

cscherrer commented 3 years ago

Thanks @thautwarm. You know this area much better than I do, but I'm not sure I'm being clear about the idea.

Suppose we were able to "freeze" the variables in local scope as a named tuple. We could carry this around with the AST. We wouldn't have the values at the type level, but we'd still have the names and the type of each value. So when we do codegen, we'd know which values need to become named tuple lookups.

But... maybe this "freezing local scope" is difficult or impossible? I have no idea on that.

I really like the idea of being able to statically splice in values at model definition time. But I hope this can be a "once in a while" thing, so we don't end up with most models having lots of $s.

Also, I think it's completely fine to have "standard" workflows to suggest for users to make things easier for them. So that's another option.

The one situation I'd really like to avoid is for models defined in local scope to be very error-prone, for example by making it unclear which variables are being referenced. The example that started this issue had me confused for a while, so I think it's very likely a casual user could get stuck.

cscherrer commented 3 years ago

Hi @thautwarm, just checking back on this. I guess the important thing is to have some way to resolve local variables. Would your suggestion for $ splicing only be needed for variables in local scope that are currently resolved incorrectly?

cscherrer commented 3 years ago

BTW just a note, I remembered we can still do

@testset "Nested models" begin
    nested = @model a, b begin
        p ~ Beta(a, b)
        x ~ Normal(p, 1.0) |> iid(3)
    end

    m = @model sub begin
        a ~ Beta(0.5, 0.5)
        b ~ Beta(1, 0.5)
        m ~ sub(a = a, b = b)
    end

    outer = m(sub=nested)
    t = xform(outer | (; m = (; x = rand(3))))

    @test logdensity(outer | (; m = (; x = rand(3))), t(randn(3))) isa Float64        
end
thautwarm commented 3 years ago

just checking back on this. I guess the important thing is to have some way to resolve local variables. Would your suggestion for $ splicing only be needed for variables in local scope that are currently resolved incorrectly?

Yes, exactly!

thautwarm commented 3 years ago

BTW just a note, I remembered we can still do

@testset "Nested models" begin
    nested = @model a, b begin
        p ~ Beta(a, b)
        x ~ Normal(p, 1.0) |> iid(3)
    end

    m = @model sub begin
        a ~ Beta(0.5, 0.5)
        b ~ Beta(1, 0.5)
        m ~ sub(a = a, b = b)
    end

    outer = m(sub=nested)
    t = xform(outer | (; m = (; x = rand(3))))

    @test logdensity(outer | (; m = (; x = rand(3))), t(randn(3))) isa Float64        
end

This way disables implicitly passing local variables to models, which could be one approach.

We could ask users if disabling free variables in models is okay?

cscherrer commented 3 years ago

We could ask users if disabling free variables in models is okay?

I think we need that; otherwise, IIUC, they wouldn't be able to bring in functions or distributions, or it would become very awkward.

I'm still wondering if StaticModules can help, but maybe that's an orthogonal question. Is it correct to say that for the $ approach, what we need is an addition to the @model macro to allow for the interpolation?

thautwarm commented 3 years ago

I'm still wondering if StaticModules can help, but maybe that's an orthogonal question.

I doubt whether StaticModules can help here. But you could ask Mason; maybe I didn't understand it correctly.

Is it correct to say that for the $ approach what we need is an addition to the @model

Yes.


Besides, I'm very sorry there is another question I didn't answer.

Currently a model is entirely static - all information is contained in the AST. Would it be possible/sensible to have a dynamic component as well?

Yes.

First-class models look good to me, and this could be a reason for using GG here.

body  = quote
   a ~ $local_var(...)
   ...
end
@model arg $body

It seems we just need to convert every Expr(:$, a) inside the second argument of @model into esc(a)? I remember there are lots of similar macros that work this way, like @btime etc.
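
For illustration, here's a minimal sketch of that rewrite using MacroTools (which Soss already depends on); splice_locals is a made-up name and this is not the actual @model implementation:

using MacroTools

# rewrite every interpolation node `$x` into `esc(x)`, leaving everything else quoted
splice_locals(ex) = MacroTools.postwalk(ex) do node
    node isa Expr && node.head == :$ ? esc(node.args[1]) : node
end

# parse from a string so the `$` survives as an Expr(:$, :m1) node
ex = Meta.parse("m ~ \$m1(a = a, b = b)")
splice_locals(ex)   # `$m1` becomes an escaped `m1`; `m`, `a`, `b` stay as plain symbols

A macro performing this transformation on its body would then let only the $-marked names resolve in the caller's (possibly local) scope.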

cscherrer commented 3 years ago

https://discourse.julialang.org/t/interpolation-in-macro-calls/25530 :)

This is such a common thing to want to do that I'm kind of surprised there's not something like a general-use @splice allowing you to declare

@splice macro foo(...) 
    ....
end

making @foo work as defined but adding splice functionality. Maybe that's not possible?

cscherrer commented 3 years ago

A discussion today with @torfjelde helped me realize that our current name resolution has some bugs when local variables are present. Here are two examples:

julia> using Soss

julia> m1 = @model begin
           a ~ @model begin
               p ~ Uniform()
               x ~ Bernoulli(p)
               return x
           end
           x ~ Normal(μ=a)
       end;

julia> rand(m1())
(x = 0.8830268940583486, a = false)

julia> m2 = @model begin
           a ~ For(3) do x Normal(μ=x) end
           x ~ Normal(μ=sum(a))
       end;

julia> rand(m2())
(x = 7.4847701873845995, a = [1.3278580687890864, 1.6612251273704697, 3.9548775107120484])

julia> digraph(m1).N
Dict{Symbol, Set{Symbol}} with 2 entries:
  :a => Set([:x])
  :x => Set([:a])

julia> digraph(m2).N
Dict{Symbol, Set{Symbol}} with 2 entries:
  :a => Set([:x])
  :x => Set([:a])

I think to fix this we need to use JuliaVariables on the AST before processing it as a DAG. @thautwarm does that sound right to you? The information we need out of it is just the dependency structure between top-level expressions.

thautwarm commented 3 years ago

Introducing JuliaVariables.jl is okay, but so far JuliaVariables.jl requires transforming every f(...) do args; expr end into f(args -> expr, ...), which is why I didn't suggest it initially. I can add a mode to JuliaVariables.jl that directly analyzes Expr(:do, ...) instead of normalizing the AST.

cscherrer commented 3 years ago

For DAG models, I think we really only need the dependencies. So for example,

julia> m = @model begin
           a ~ For(3) do x Normal(μ=x) end
           x ~ Normal(μ=sum(a))
       end;

julia> m.dists.a  |>  MacroTools.prettify |> simplify_ex |>  solve_from_local  |> unwrap_scoped
:((@global For)(function (x,)
          (@global Normal)(μ = @local x)
      end, 3))

This tells us that a depends on For and Normal. But neither of those is a Soss variable name for this model, so in the generated code a can come first. OTOH,

julia> m.dists.x  |>  MacroTools.prettify |> simplify_ex |>  solve_from_local  |> unwrap_scoped
:((@global Normal)(μ = (@global sum)(@global a)))

tells us that x depends on Normal, sum, and a. Of these, a is the only one that's a variable name for this model.

Together, these tell us that the graph must be a → x.

For DAG models, this is enough. We don't need to hang on to the annotated result.
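
As a rough sketch of that extraction (hypothetical helper names; this reuses the simplify_ex |> solve_from_local |> unwrap_scoped pipeline shown above and assumes m.dists maps each variable to its expression, as in the examples): only the @global-marked names can refer to model variables, since @local ones such as the do-block argument x are bound inside the expression. So the parents of a variable are the @global names that are also variables of the model.

using MacroTools

solved(ex) = ex |> MacroTools.prettify |> simplify_ex |> solve_from_local |> unwrap_scoped

# collect every name wrapped in `@global` in the solved expression
function global_names(ex)
    names = Symbol[]
    MacroTools.postwalk(ex) do node
        if node isa Expr && node.head == :macrocall && node.args[1] == Symbol("@global")
            arg = node.args[end]
            arg isa Symbol && push!(names, arg)
        end
        node
    end
    names
end

# hypothetical: parents of variable v are the global names that are also model variables
parents(m, v) = intersect(global_names(solved(m.dists[v])), keys(m.dists))

Here parents(m, :x) would give [:a] and parents(m, :a) would be empty, recovering the a → x graph.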

I'm starting to think about what this might look like for more flexible models where there can be control flow. Then we can have an ASTModel that's more flexible but makes fewer guarantees, and a more limited DAGModel that can do more in terms of the static analysis.

For example, I think something like this can be very useful: https://github.com/cscherrer/Soss.jl/blob/cs-astmodels/src/primitives/interpret.jl

For an ASTModel it might turn out to be more important to hang on to the annotated code.