probcomp / Gen.jl

A general-purpose probabilistic programming system with programmable inference
https://gen.dev
Apache License 2.0

SML explores all branches in short-circuit evaluation #264

Open femtomc opened 4 years ago

femtomc commented 4 years ago

The following chunk of code passes the static syntax checks:

@gen (static) function thinking!(t::Int, 
                        is::InteractionState
                       )::InteractionState

    tree, env = is.tree, is.env

    # Pre-updates.
    tree, env = update_actions!(tree, env) # deterministic
    tree = update_STCC!(tree, env) # non-deterministic
    tree = update_conflicts!(tree) # deterministic

    # Step selection.
    step = @trace(step_arbiter(tree), :step_arbiter) # non-deterministic

    # Execution.
    c1 = type(step) == MentalAction && begin
        tree, env = @trace(execute_MA!(step, tree, env), :execute_MA) 
        return InteractionState(tree, env)
    end || false
    c2 = type(step) == Goal && begin
        tree, env = @trace(execute_G!(step, tree, env), :execute_G) 
        return InteractionState(tree, env)
    end || false
    c3 = type(step) == Action && begin
        tree, env = @trace(execute_A!(step, tree, env), :execute_A) 
        return InteractionState(tree, env)
    end || false

    return InteractionState(tree, env)
end

However, the semantics are totally different from what the code implies. The stack trace includes this frame:

 [4] generate(::Main.PathPlanningAgent.GenerativeHap.var"##StaticGenFunction_execute_A!#732", ::Tuple{Main.PathPlanningAgent.GenerativeHap.Step{Main.PathPlanningAgent.GenerativeHap.Goal,Main.PathPlanningAgent.GenerativeHap.Persistent,Main.PathPlanningAgent.GenerativeHap.IgnoreFailure,Main.PathPlanningAgent.GenerativeHap.EffectOnly,NamedTuple{(:name, :specificity, :preconditions, :expander),Tuple{Symbol,Float64,Array{Main.PathPlanningAgent.GenerativeHap.Actuator,1},Main.PathPlanningAgent.var"#1#8"{Symbol,DataType,Float64}}}},FunctionalCollections.PersistentHashMap{Symbol,Any},FunctionalCollections.PersistentHashMap{Symbol,Any}}, ::Gen.EmptyChoiceMap) at /home/mccoy/.julia/dev/Gen/src/static_ir/generate.jl:114

This is the call at the :execute_A address. That branch should not be reached when executing this code: the step in this frame is a Step{Goal, ...}, not an Action.


A minimal working example (MWE) outside of my library:

module MWEShortCircuits

using Gen

@gen function baz(b::Bool)
    if b
        println("Recording...")
    end
    return false
end

@gen (static) function foo()
    x = true
    y = x && begin
        q = @trace(baz(x), :q1)
    end || true

    z = !x && begin
        q = @trace(baz(x), :q2)
        return true
    end || false
    return (x, y, z)
end

@gen function foo2()
    x = true
    y = x && begin
        q = @trace(baz(x), :q1)
    end || true

    z = !x && begin
        q = @trace(baz(x), :q2)
        return true
    end || false
    return (x, y, z)
end

@Gen.load_generated_functions()

tr, _ = generate(foo, ())
println(get_choices(tr))
println(get_retval(tr))
tr, _ = generate(foo2, ())
println(get_choices(tr))
println(get_retval(tr))
end

produces

Recording...
Recording...
Gen.EmptyChoiceMap()
(true, true, false)
Recording...
Gen.EmptyChoiceMap()
(true, true, false)

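The choice maps are empty in both cases (baz makes no random choices), yet the static foo printed "Recording..." twice: both @trace calls to baz ran, including the :q2 call that short-circuit evaluation of !x && ... should have skipped. The dynamic foo2 printed it only once. So the semantics differ between the two languages!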
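For comparison, here is the same control flow in plain Julia, with no Gen involved. It is a minimal sketch in which baz_plain is a hypothetical stand-in for baz that records the address it was called with instead of tracing it. Ordinary Julia evaluates only the :q1 call, which matches the dynamic foo2:

```julia
# Plain-Julia analogue of foo/foo2 from the MWE (no Gen involved).
# `calls` records which hypothetical "addresses" were actually evaluated.
const calls = Symbol[]

# Hypothetical stand-in for baz: record the address, then return false like baz.
baz_plain(b::Bool, addr::Symbol) = (push!(calls, addr); false)

function foo_plain()
    x = true
    # Parses as (x && begin ... end) || true, so y == (false || true) == true.
    y = x && begin
        q = baz_plain(x, :q1)
    end || true

    # !x is false, so this block (and its early return) is skipped entirely.
    z = !x && begin
        q = baz_plain(x, :q2)
        return true
    end || false
    return (x, y, z)
end

result = foo_plain()
println(calls)    # only :q1 was evaluated
println(result)

# The precedence behind y == true: && binds tighter than ||.
ex = Meta.parse("a && b || c")
println(ex.head, " ", ex.args[1].head)
```

The return value (true, true, false) agrees with both outputs above; only the number of evaluated calls differs from the static version.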

alex-lew commented 4 years ago

@femtomc, it looks like some recent relaxations of the syntactic restrictions on the static language have caused us to let some things through that should be blocked. Thanks for bringing this to our attention, and I'm sorry this bit you!

@ztangent Can you take a look at this? My intuition is that anything that allows for conditional evaluation (ternary expressions, if statements in begin ... end blocks, short-circuiting && and ||, etc.) should not be allowed to contain @trace expressions. I'm also a little surprised we allow the same variable name (tree) multiple times -- was that something you added support for? (I forget. If so, great!)

@femtomc It might help to know that the semantic restrictions on the Static language are that:

These should both be better documented, and it's definitely a bug that the static language allows the syntax you've pasted above :-) Thanks for the example.

In general, in the static language, the way to handle conditionals is to call out to a dynamic DSL function that uses the conditional logic. Then the high-level structure of the program can still be static, enabling efficient incremental computation when only some addresses are updated, while leaving you the flexibility to use stochastic control flow further down in the call stack. @fsaad has also explored adding a Cond combinator, which I still think is a good idea (and could be a good warm-up project for someone looking to contribute to Gen!).

ztangent commented 4 years ago

@alex-lew Yes, we added support for reusing variable names!

Mutation indeed is not allowed, but that currently can't be checked for at parse time (and I don't think it ever can be? At best we can flag calls to functions whose names end with ! as a heuristic).
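The !-suffix heuristic is easy to sketch over surface syntax. This is a minimal illustration, assuming the function body is available as a quoted Julia expression; find_bang_calls is a hypothetical helper, not part of Gen's parser. It is only a heuristic: a field update such as b.x += 1.0 contains no !-call in the surface syntax and slips through.

```julia
# Hypothetical helper: collect the names of called functions whose names end in "!".
# This relies on Julia's naming convention for mutating functions, so it is a
# heuristic, not a guarantee (functions without ! can still mutate arguments).
function find_bang_calls(ex, found = Symbol[])
    ex isa Expr || return found
    if ex.head == :call && ex.args[1] isa Symbol && endswith(string(ex.args[1]), "!")
        push!(found, ex.args[1])
    end
    for arg in ex.args
        find_bang_calls(arg, found)
    end
    return found
end

# Statements adapted from the thinking! example earlier in the issue.
body = quote
    tree, env = update_actions!(tree, env)
    tree = update_STCC!(tree, env)
    tree = update_conflicts!(tree)
    step = @trace(step_arbiter(tree), :step_arbiter)
    b.x += 1.0   # mutates b, but is NOT flagged by this heuristic
end

println(find_bang_calls(body))
```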

Conditional evaluation is also not allowed or properly supported, as this example shows. Currently the parser fails to block it: it only blocks top-level control-flow constructs, such as if/else statements, but accepts any Julia expression on the right-hand side of an assignment. This includes any begin ... end blocks, as well as any @trace expressions nested inside them, which is why the code passes all the static syntax checks.

The reason why foo explores all branches in the MWE shown above is that the static parser constructs an intermediate representation with a new node for every @trace expression it encounters. So when the generative function is run, all of those @trace calls get run. Short-circuit evaluation is not respected, because the control-flow structure induced by short-circuit evaluation isn't compiled into the intermediate representation.
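To make the mechanism concrete, here is a deliberately simplified model. TraceNode and run_all are hypothetical names, and this is not Gen's actual StaticIR data structure; it only shows that a flat list of call nodes with no record of the surrounding && / || structure must execute every node:

```julia
# Hypothetical, highly simplified "static IR": one node per @trace expression,
# with no record of the && / || control flow that surrounded it in the source.
struct TraceNode
    addr::Symbol
    thunk::Function   # the traced call, deferred
end

# Executing this IR visits every node, so every traced call runs.
function run_all(nodes::Vector{TraceNode})
    executed = Symbol[]
    for node in nodes
        node.thunk()
        push!(executed, node.addr)
    end
    return executed
end

x = true
# One node per @trace found in foo, including the :q2 call that
# short-circuit evaluation of `!x && ...` would have skipped.
nodes = [TraceNode(:q1, () -> x), TraceNode(:q2, () -> x)]
println(run_all(nodes))
```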

Note also that the assignment q = @trace(baz(x), :q1) will be ignored, because it occurs within a begin ... end block, and not at the top-level of the function body. So it won't be possible to access the value of q afterwards.

It's not entirely clear to me yet how to block this kind of code at parse time without writing too many special cases. At least some kinds of conditional evaluation are valid and useful, e.g., you might want to write

@gen (static) function flip_coin()
    face = @trace(bernoulli(0.5), :face) ? "heads" : "tails"
    return face
end

This will work as expected, because the @trace call is within the condition, which the user expects to always execute. So it seems like we don't want to block this use of ternary expressions. But if the user writes an @trace call within one of the branches of the ternary expression instead, this will cause unexpected behavior vis-à-vis short-circuit evaluation.

ztangent commented 4 years ago

I'll just add that another reason why we might not want to block constructs like ternary expressions, begin ... end blocks, and short-circuiting && is that these don't cause any issues if they only contain ordinary Julia code. Apart from variable assignment, all of those constructs will work as expected (and in fact, variable assignment that is completely local to, e.g., a begin ... end block should work fine as well; it just won't work if you expect it to modify some variable from the outer scope).

So as @alex-lew mentioned, we really only want to block parsing for the cases where @trace expressions end up within those constructs. But that's going to be a little difficult using MacroTools.postwalk alone. The issue is that detecting those cases is no longer context-free, in the sense that we can't just walk the AST from bottom-up and pattern match to catch disallowed constructs -- we'd need to detect that there is an @trace call, remember that fact, and then walk further up the tree until we find that the @trace call happens to be nested within one of the disallowed expressions.
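One way to avoid remembering state on the way back up is to walk top-down and carry the context along instead. This is a minimal sketch over Julia's Expr trees using only Base, not MacroTools, and assuming @trace appears as a plain @trace macro call (a qualified Gen.@trace would need extra handling). find_conditional_traces is a hypothetical helper, not part of Gen. The first operand of && / || and the condition of an if or ternary always run, so an @trace there is allowed, as in flip_coin; an @trace in a later operand or a branch is flagged.

```julia
# Hypothetical checker: return every @trace call that sits in a position that is
# only conditionally evaluated (a later operand of && / ||, or an if/ternary branch).
is_trace(ex) = ex isa Expr && ex.head == :macrocall && ex.args[1] == Symbol("@trace")

function find_conditional_traces(ex, conditional::Bool = false, found = Expr[])
    ex isa Expr || return found
    if is_trace(ex) && conditional
        push!(found, ex)
    end
    if ex.head in (:&&, :||, :if, :elseif)
        # The first argument (left operand or condition) always runs...
        find_conditional_traces(ex.args[1], conditional, found)
        # ...but everything after it may be skipped.
        for arg in ex.args[2:end]
            find_conditional_traces(arg, true, found)
        end
    else
        for arg in ex.args
            find_conditional_traces(arg, conditional, found)
        end
    end
    return found
end

mwe = quote
    y = x && begin
        q = @trace(baz(x), :q1)
    end || true
    z = !x && begin
        q = @trace(baz(x), :q2)
        return true
    end || false
end

flip = quote
    face = @trace(bernoulli(0.5), :face) ? "heads" : "tails"
end

# Addresses of the flagged calls (the address is the last macro argument).
println([t.args[end].value for t in find_conditional_traces(mwe)])
println(isempty(find_conditional_traces(flip)))
```

Because the context flag is passed down rather than reconstructed upward, a single recursive pass is enough; the trade-off is that every new conditional construct still needs its own case in the head check.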

femtomc commented 4 years ago

Just a comment as I digest the other points here:

You actually can detect mutation if you drop down to lowered code or IR; for example:

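module MutationDetector

using IRTools
using MacroTools

mutable struct Baz
    x::Float64
end

# Mutates its argument; lowers to a Base.setproperty! call.
foo(b::Baz) = b.x += 1.0

b = Baz(10.0)
ir = @code_ir foo(b)
println(ir)

# Walk every statement in the IR and error if any of them calls setproperty!.
function pass(ir)
    for (v, st) in ir
        MacroTools.postwalk(st.expr) do el
            el isa GlobalRef && el.name == :setproperty! && error("No mutation!")
            el  # postwalk rebuilds the expression, so return each node unchanged
        end
    end
end

pass(ir)  # expected to throw "No mutation!" for foo

end # module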

One other comment: my code actually works fine in the dynamic modeling language (DML); I was just curious whether I could express it in the static modeling language (SML).

Edit: I am also a bit worried about the mutation, but the support of my choices is always categorical. I think this means that, in practice, my log ratios log f - log g will just drop to -Inf whenever a proposed choice lies outside the support of f (because g is always small but positive across the possible choices).
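The arithmetic behind that last point, as a small sketch: if a proposed choice lies outside the support of f, then f assigns it probability 0, so log f = -Inf, and subtracting any finite log g still gives -Inf. The probabilities below are made-up illustrative values, not taken from the issue:

```julia
# Illustrative values only: f gives the proposed choice zero probability,
# while the proposal g gives it a small but positive probability.
f_prob = 0.0
g_prob = 1e-3

log_ratio = log(f_prob) - log(g_prob)   # log(0.0) == -Inf; log(g_prob) is finite
weight = exp(log_ratio)                 # exp(-Inf) == 0.0

println(log_ratio)
println(weight)
```

Such a proposal simply gets zero weight. The positivity of g matters: if g were also 0, the ratio would be -Inf - (-Inf), which is NaN.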

femtomc commented 4 years ago

One other interesting tidbit: short-circuit evaluation and conditional branching both desugar to "br ... unless" branch instructions in the IR:

module ShortCircuitLowered

using IRTools

function foo1(x::Bool)
    x && begin
        y = 10
        return y
    end
    return 15
end

function foo2(x::Bool)
    if x
        y = 10
        return y
    else
        return 15
    end
end

ir_1 = @code_ir foo1(true)
ir_2 = @code_ir foo2(true)

println("foo1: $(ir_1)\n")
println("foo2: $(ir_2)\n")

end

produces:

foo1: 
1: (%1, %2)
  br 4 unless %2
2:
  return 10
3:
  br 4
4:
  return 15

foo2: 
1: (%1, %2)
  br 3 unless %2
2:
  return 10
3:
  return 15

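So at the IR level both constructs reduce to the same "br ... unless" form, and a parser working there would only need to handle that one form. The question is just what level the parser operates at. I really don't know enough about the SML to know whether this is a viable suggestion, but it's some food for thought.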

Actually, this is sort of cool, let me try some experiments to introspect Gen code at the IR level. In particular, I want to see what the static parser sees.
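The same observation can be checked with Base alone, using code_lowered instead of IRTools (a swapped-in tool, not the one used above). This minimal sketch assumes a recent Julia version (1.6 or later, to our understanding), where lowered code represents conditional jumps as Core.GotoIfNot nodes. Both the && form and the if/else form should contain one:

```julia
# Two functions with the same control flow, written with && and with if/else.
g() = 10
short_circuit(x::Bool) = x && g() > 0
branching(x::Bool) = if x
    g() > 0
else
    false
end

# Collect the conditional-jump nodes from each function's lowered code.
gotoifnots(f) = filter(s -> s isa Core.GotoIfNot, code_lowered(f, (Bool,))[1].code)

println(gotoifnots(short_circuit))
println(gotoifnots(branching))
```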

alex-lew commented 4 years ago

Thanks — great points all!

@ztangent Does it make sense to have different parsing contexts, one for “top-level” and one for “only Julia code allowed”? In the only Julia code allowed setting, we could detect

and throw an error if either occurred. (Not saying you need to implement this, just wondering aloud if it would work.)

The possibility of implementing the static parser at a lower level of abstraction (e.g. Julia’s IR) is also intriguing longer-term :-)