cscherrer / Soss.jl

Probabilistic programming via source rewriting
https://cscherrer.github.io/Soss.jl/stable/
MIT License

DAG construction with local variables #253

Closed: cscherrer closed this issue 3 years ago

cscherrer commented 3 years ago

I had brought this up in https://github.com/cscherrer/Soss.jl/issues/245, but I think it's really a different issue, so I'm splitting it.

The problem shows up in models like

m = @model begin
    a ~ For(3) do x Normal(μ=x) end
    x ~ Normal(μ=sum(a))
end

Soss thinks the graph is

julia> digraph(m).N
Dict{Symbol, Set{Symbol}} with 2 entries:
  :a => Set([:x])
  :x => Set([:a])

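But this is wrong! The x referred to in a ~ ... is a local variable (the argument of the do block), not the model variable x. So a does not depend on x at all; only x depends on a.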

I think JuliaVariables.jl ought to be able to help with this. The idea is:

  1. Get the rhs for each statement
  2. Solve each rhs, which adds @local and @global annotations
  3. Extract the set of @global ones, which should give the true dependencies
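To make the intended result of these three steps concrete, here is a minimal hand-rolled sketch. It does not use JuliaVariables and is not how Soss builds its graph; free_symbols and model_deps are hypothetical helpers. They treat only do-block and -> arguments as local bindings (assignments, let blocks, comprehensions and typed parameters are deliberately not handled) and then keep the free symbols that are also model variables.

```julia
# Hypothetical sketch (not JuliaVariables, not Soss): find which model
# variables each right-hand side really depends on, treating do-block and
# `->` arguments as locals. Only lambda parameters count as bindings here;
# `let`, assignments, comprehensions and typed parameters (x::T) are ignored.

walk!(out, ::Any, bound) = out                       # literals, LineNumberNodes, QuoteNodes
walk!(out, s::Symbol, bound) = (s in bound || push!(out, s); out)

function walk!(out, ex::Expr, bound)
    if ex.head === :->
        params = ex.args[1]
        names = params isa Symbol ? Symbol[params] :
                Symbol[p for p in params.args if p isa Symbol]
        walk!(out, ex.args[2], union(bound, names))   # body sees params as local
    elseif ex.head === :kw
        walk!(out, ex.args[2], bound)                 # skip the keyword name (e.g. μ)
    else
        foreach(a -> walk!(out, a, bound), ex.args)   # :do, :call, :block, ...
    end
    return out
end

# Step 2 (roughly): symbols that occur free, i.e. "global", in `ex`.
free_symbols(ex) = walk!(Set{Symbol}(), ex, Set{Symbol}())

# Step 3: for each rhs, keep only the free symbols that are model variables.
model_deps(rhs::Dict{Symbol}) =
    Dict(v => intersect(free_symbols(ex), keys(rhs)) for (v, ex) in rhs)

# Step 1: the right-hand sides of the model in this issue.
rhs = Dict(
    :a => :(For(3) do x
        Normal(μ = x)
    end),
    :x => :(Normal(μ = sum(a))),
)
deps = model_deps(rhs)   # deps[:a] is empty, deps[:x] == Set([:a])
```

The point of the sketch is the binding rule: once the do-block parameter x is tracked as local, a no longer appears to depend on x.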

Here's my first attempt:

julia> using JuliaVariables

julia> using MacroTools: prettify

julia> m.dists.a |> solve_from_local |>  prettify |> unwrap_scoped
:((@global For)(3) do @global x
      @global Normal
      $(Expr(:kw, :μ, @global x))
  end)

julia> m.dists.x |> solve_from_local |>  prettify |> unwrap_scoped
:((@global Normal)(μ = (@global sum)(@global a)))

I'm not sure what's going on in that @global x, since x is clearly a local variable. My guess is that the do notation is throwing off the solver and needs to be rewritten.
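One way to check that guess is to look at how Julia itself parses the do block. Using only Base, the snippet below shows that the do syntax becomes Expr(:do, call, lambda), where lambda is an Expr(:->, ...) whose first argument holds the local parameter x. A scope analysis that does not treat :-> as a binding form would plausibly report that x as global.

```julia
# Parse the do-block rhs from the model and inspect its structure (Base only).
ex = Meta.parse("""
For(3) do x
    Normal(μ = x)
end
""")

@show ex.head            # :do
call, lambda = ex.args
@show call               # the call the block is attached to: For(3)
@show lambda.head        # :->, i.e. an anonymous function
@show lambda.args[1]     # Expr(:tuple, :x): the do-block's local parameter
```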

thautwarm commented 3 years ago

Call simplify_ex once before calling solve_xxx. JuliaVariables does not handle Expr(:->, ...) lambdas, only Expr(:function, ...).
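As an illustration of the kind of rewrite this implies (a sketch that assumes nothing about simplify_ex's actual internals), desugar_do below is a hypothetical helper. It turns f(args...) do params ... end into f(function (params...) ... end, args...), which is the same shape as the simplify_ex output later in this thread, and converts any remaining -> lambda into Expr(:function, ...).

```julia
# Hypothetical helper, not part of JuliaVariables: remove `do` and `->`
# syntax so that every anonymous function is an Expr(:function, ...).
function desugar_do(ex)
    ex isa Expr || return ex                  # symbols, literals, line numbers
    args = map(desugar_do, ex.args)           # rewrite children first
    if ex.head === :do && args[1] isa Expr && args[1].head === :call
        call, lambda = args                   # lambda is already Expr(:function, ...)
        f, rest = call.args[1], call.args[2:end]
        if !isempty(rest) && rest[1] isa Expr && rest[1].head === :parameters
            # keyword parameters must stay first: f(lambda, pos...; kw...)
            return Expr(:call, f, rest[1], lambda, rest[2:end]...)
        end
        return Expr(:call, f, lambda, rest...)   # f(lambda, pos...)
    elseif ex.head === :->
        params = args[1] isa Symbol ? Expr(:tuple, args[1]) : args[1]
        return Expr(:function, params, args[2])
    else
        return Expr(ex.head, args...)
    end
end

ex = :(For(3) do x
    Normal(μ = x)
end)
out = desugar_do(ex)   # For(function (x,) Normal(μ = x) end, 3)
```

The rewritten expressions are ordinary Julia code, so for example eval(desugar_do(:(x -> x + 1))) still produces a working anonymous function.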

cscherrer commented 3 years ago

Got it, thank you @thautwarm !

julia> using MLStyle: @match

julia> unwrap_scoped(ex) =
           @match ex begin
               Expr(:scoped, _, a) => unwrap_scoped(a)
               Expr(head, args...) => Expr(head, map(unwrap_scoped, args)...)
               a => a
           end
unwrap_scoped (generic function with 1 method)

julia> m = @model begin
           a ~ For(3) do x Normal(μ=x) end
           x ~ Normal(μ=sum(a))
       end;

julia> m.dists.a |> prettify |> simplify_ex |> solve_from_local |> unwrap_scoped
:((@global For)(function (x,)
          (@global Normal)(μ = @local x)
      end, 3))
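The unwrap_scoped helper used here relies on MLStyle's @match. For anyone reproducing this without MLStyle, the same unwrapping can be written with plain multiple dispatch. unwrap_scoped_plain is a hypothetical name, and, like the @match pattern, it assumes a :scoped node has exactly two arguments: the scope information followed by the wrapped expression.

```julia
# Base-only equivalent of the @match-based unwrap_scoped:
# strip every Expr(:scoped, info, body) wrapper, recursing into subexpressions.
unwrap_scoped_plain(ex) = ex                     # symbols, literals, annotated variables, ...
function unwrap_scoped_plain(ex::Expr)
    if ex.head === :scoped && length(ex.args) == 2
        return unwrap_scoped_plain(ex.args[2])   # drop the scope info
    end
    return Expr(ex.head, map(unwrap_scoped_plain, ex.args)...)
end

nested = Expr(:call, :f, Expr(:scoped, :info1, Expr(:scoped, :info2, :x)))
unwrap_scoped_plain(nested) == :(f(x))   # true
```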