probcomp / Gen.jl

A general-purpose probabilistic programming system with programmable inference
https://gen.dev
Apache License 2.0

Dispatch combinator to handle multiple dispatch in generative functions #255

Status: Open. femtomc opened this issue 4 years ago.

femtomc commented 4 years ago

As discussed on the Gen Slack, one thing that may confuse people who are new to Gen is that generative functions are structs. This means programmers must avoid normal Julia multiple-dispatch idioms when defining them.

The following code illustrates the problem. It fails "silently": argument types are not checked in generate, and the second generative function definition overwrites the first. As a result, the support of foo is completely different from what the programmer probably expects, and nothing tells them so.
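
The overwrite can be reproduced without Gen. The sketch below contrasts ordinary Julia methods, which accumulate, with a name bound to a struct value, which a second definition simply rebinds. `ToyGenFn` and `run_toy` are hypothetical stand-ins, not Gen types; they only mimic the two properties described above (the generative function is a struct bound to a name, and its declared argument types are not checked).

```julia
# Ordinary Julia functions accumulate one method per signature;
# dispatch selects a method from the runtime argument types.
describe(y::Int) = "Int method"
describe(y::Float64) = "Float64 method"

# Hypothetical stand-in for a generative function (NOT Gen's type):
# a struct value holding a body plus declared argument types that
# are recorded but never checked when the body runs.
struct ToyGenFn
    body::Function
    arg_types::Tuple
end

run_toy(g::ToyGenFn, args...) = g.body(args...)  # no check against g.arg_types

# Two "definitions" under one name: the second simply rebinds `toy`,
# so the Int version is lost instead of being kept as a second method.
toy = ToyGenFn(y -> "Int body", (Int,))
toy = ToyGenFn(y -> "Float64 body", (Float64,))

println(describe(5))      # Int method
println(describe(5.0))    # Float64 method
println(run_toy(toy, 5))  # Float64 body, even though 5 is an Int
```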

module MWEDispatchIssues

using Gen

@gen function foo(y::Int)
    if y < 5
        x = @trace(normal(0.0, 1.0), :x)
    else
        x = @trace(categorical([0.5, 0.5]), :x)
    end
    return x
end

@gen function foo(y::Float64)
    if y > 10
        x = @trace(normal(0.0, 10.0), :x)
        z = @trace(normal(0.0, 10.0), :z)
    else
        x = @trace(normal(0.0, 1.0), :x)
        z = @trace(normal(0.0, 1.0), :z)
    end
    return x, y
end

tr, _ = generate(foo, (5,))
display(get_choices(tr))

end

There have been several suggestions on Slack for solving this issue; I am posting them below.

From George Matheos:

Alex Lew: I agree that the conceptually neatest way to do the dispatch is as a combinator! This might be the right route implementation-wise too, since it's easy to understand, though I'm wondering whether we might instead be able to change the implementation of the static and dynamic DSLs slightly so that they implicitly implement this sort of combinator. I think what keorn is saying about relying on regular Julia method dispatch could work pretty well. I haven't ironed out all the details, but here is what I'm thinking. For the static DSL, the GFI methods are all custom generated methods, meaning we have a separate method for each concrete argument type. We could augment the trace type of the static DSL to include the argument types, and then declare a separate generated GFI method for each declared static DSL method, so that, depending on the trace and argument types, Julia automatically uses the code for the proper user-declared method.
For this, we'd need the @gen (static) macro to log some metadata so that we can do this method dispatch later, but this should be pretty doable, I think. In terms of "safety", what we want to avoid is users trying to update a trace generated by one generative method by changing its arguments so that it would have been generated by a different generative method. In this implementation, that could never happen, since we would never declare an update method whose args belong to a different generative method than the trace. (If we want to make this type of update possible without an error, and simply return a weight of -Inf most of the time, we'd need to create some additional methods for update; this is probably the right approach rather than throwing an error.)

For the dynamic DSL, I am not totally clear on how to do it, but I'm hoping there is some way to mostly take advantage of automatic Julia dispatch on the defined function. Most GFI methods involve calling exec on an underlying Julia function; can we just use standard multiple dispatch on that? Maybe; we would also need to change how the dynamic macro compiles things so that the underlying function of a dynamic gen function can have multiple methods; my guess is that we could figure this out, though. In terms of "safety", I think we probably don't need to worry too much if a user updates the argument types: we can view a Julia function as always beginning with a "switch" statement on argument types to select the method, which the dynamic DSL should be able to support by default.
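
A minimal, Gen-free sketch of the static-DSL idea above, assuming a trace type parameterized by the argument types. `MethodTrace` and `toy_update` are hypothetical names, the choice values are placeholders, and the weights are dummies. The point is only that Julia dispatch selects the per-method update, with a fallback that returns a weight of -Inf as suggested in the parenthetical.

```julia
# Hypothetical sketch (not Gen's API): the trace type records the
# argument types of the method that produced it.
struct MethodTrace{A<:Tuple}
    args::A
    choices::Dict{Symbol,Float64}
end

# One update method per declared method. Dispatch only matches when the new
# args have the same types as the trace's args. Weights are dummy values:
# 0.0 because these toy bodies ignore their argument.
toy_update(tr::MethodTrace{Tuple{Int}}, new_args::Tuple{Int}) =
    (MethodTrace(new_args, copy(tr.choices)), 0.0)

toy_update(tr::MethodTrace{Tuple{Float64}}, new_args::Tuple{Float64}) =
    (MethodTrace(new_args, copy(tr.choices)), 0.0)

# Fallback for args that would select a different method: return the trace
# unchanged with weight -Inf instead of throwing an error.
toy_update(tr::MethodTrace, new_args::Tuple) = (tr, -Inf)

tr = MethodTrace((1,), Dict(:y => 0.3, :z => 0.31))
_, w_same = toy_update(tr, (2,))      # stays within the Int method
_, w_switch = toy_update(tr, (1.0,))  # would switch to the Float64 method
println((w_same, w_switch))           # (0.0, -Inf)
```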

From Alex Lew:

George Matheos: are you sure about update? What if the arg types change because a random choice in my caller changed? (E.g., maybe my caller generates a tree and passes it in, and it changes from a LeafNode to an InternalNode.) Conceptually, a changing arg type would be no different from a changing argument, yeah? If the user's call to update would lead to an impossible state, we throw an error (as we already have to do). We would also need to consider what happens when some declarations are static and others are dynamic. I like the combinator approach because it could isolate the logic for all this in (mostly) one place in the codebase, and let us reason clearly about all the requirements to make sure nothing else breaks. But I'm open to considering other approaches too, particularly if performance is an issue. I agree with keorn that we could still use Julia's existing machinery to decide how to dispatch.

From George Matheos:

I see what you mean, Alex: if we want to support dispatch where some methods are static and some are dynamic, the combinator makes a lot more sense. Performance-wise, I bet we could also get it almost as good as without a combinator if we take a "static DSL" approach to the combinator and compile separate GFI methods for each declaration of a gen function.


From this discussion, the core proposal seems to be a "dispatch combinator" that, like the static DSL, compiles GFI methods specialized to the argument types. The combinator would apply to generative functions written in both the static and the dynamic languages.
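
A minimal sketch of the combinator idea, assuming (as keorn and George suggest) that Julia's own dispatch selects the method. It is Gen-free: `DispatchCombinator`, `select_foo`, `foo_int`, and `foo_float` are hypothetical, and the bodies return strings instead of traces. A Gen version would presumably forward each GFI method (simulate, generate, update, and so on) to the selected generative function, and would still have to decide what update does when the selected method changes.

```julia
# Hypothetical Gen-free sketch of a dispatch combinator. The combinator
# wraps an ordinary Julia function whose methods map argument types to an
# implementation, so Julia's own dispatch rules choose the method.
struct DispatchCombinator{F}
    select_impl::F
end

# Calling the combinator selects an implementation, then runs it.
(d::DispatchCombinator)(args...) = d.select_impl(args...)(args...)

# Stand-ins for the two bodies of `foo` in the MWE (strings instead of traces).
foo_int(y::Int) = y < 5 ? "normal x" : "categorical x"
foo_float(y::Float64) = y > 10 ? "wide x, z" : "narrow x, z"

select_foo(::Int) = foo_int
select_foo(::Float64) = foo_float

foo = DispatchCombinator(select_foo)

println(foo(5))     # categorical x (the Int method now runs)
println(foo(12.0))  # wide x, z
# foo("a") throws a MethodError instead of silently running the wrong body.
```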

Edit: I'm opening this issue because I'm interested in supporting this, as it would simplify a lot of code I've been writing in Gen recently.

Edit2: fixed all the @ references.

femtomc commented 4 years ago

@georgematheos

keorn commented 4 years ago

To encourage progress here, I would like to offer a bounty of roughly $500 for the following issue: https://gitlab.com/plantingspace/broadleaf/-/issues/1

The combinator strategy above is one possible way to implement generative methods.

You can read more detailed instructions about the terms here.

georgematheos commented 4 years ago

@keorn thanks for the bounty! I think we can do this pretty quickly, but the way I have in mind will only work under the following restriction:

For any "generative function with methods", either:

  1. the different methods all have disjoint address spaces, or
  2. the user will never call Gen's update function to switch between methods.
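
Condition 1 can only be checked when each method's address space is known up front, which the next comments note is possible for static-DSL functions but not for dynamic ones. A hypothetical helper, assuming address spaces are given as `Set{Symbol}`:

```julia
# Hypothetical helper: assumes each method's address space is known ahead of
# time (possible for static-DSL methods, not for dynamic-DSL ones) and checks
# condition 1, i.e. that no two methods share an address.
function disjoint_address_spaces(spaces::Vector{Set{Symbol}})
    for i in eachindex(spaces), j in eachindex(spaces)
        if i < j && !isempty(intersect(spaces[i], spaces[j]))
            return false
        end
    end
    return true
end

# The two `foo` methods in the example that follows both use address :y.
println(disjoint_address_spaces([Set([:y, :z]), Set([:y])]))  # false
println(disjoint_address_spaces([Set([:x, :z]), Set([:y])]))  # true
```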

To illustrate why it is a problem if neither of these conditions is satisfied, here's an example. Say we have the following setup:

@gen function foo(x::Int)
    y ~ normal(0, 1)
    z ~ normal(y, 0.05)
    return z
end
@gen function foo(x::Float64)
    y ~ normal(0, 1)
    return y
end
tr = simulate(foo, (1,)) # calls the first method, with `z`

If we call

update(tr, (1.0,), (UnknownChange(),), EmptyChoiceMap())

this should trigger a switch from the first method (with z) to the second method (without z). According to the semantics of the update function, the value of y from the original trace should end up in the second method's trace. This means that when we generate a trace for the foo(::Float64) method, we need to know to pass in y as a constraint, but not z. (If we constrain z but the method we are switching to has no z address, the current implementation throws an error. Perhaps it makes sense to have a way to loosen this restriction. Maybe @marcoct has some thoughts on this.)

To know which addresses to constrain and which not to constrain, we need some way to "peek" into the generative function. However, the dynamic generative function DSL (invoked by @gen function) is built to be entirely black-box: we can't "look inside". (For static DSL functions [built using @gen (static) function], it should be possible to see what addresses need to be used.) As a result, I don't think there's an easy way to automatically handle this sort of update.
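
If the target method's addresses were known statically, the carry-over described above would reduce to a set-membership split of the old choices. A minimal sketch under that assumption; `split_for_switch` is hypothetical and the choice values are made up:

```julia
# Hypothetical sketch: assumes the target method's address set is known
# statically, which is exactly what the dynamic DSL's black-box design rules out.
function split_for_switch(old_choices::Dict{Symbol,Float64}, target_addrs::Set{Symbol})
    # Old choices that the target method will visit become constraints...
    constraints = Dict(a => v for (a, v) in old_choices if a in target_addrs)
    # ...and the rest are discarded rather than passed in (which would error).
    discarded = Dict(a => v for (a, v) in old_choices if !(a in target_addrs))
    return constraints, discarded
end

# Switching foo(1) -> foo(1.0): the Float64 method only visits :y.
old = Dict(:y => 0.3, :z => 0.31)
constraints, discarded = split_for_switch(old, Set([:y]))
println(constraints)  # holds only :y
println(discarded)    # holds only :z
```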

@keorn if you want this in the short-term, I can implement it using this restriction. (So if you just want dispatch but don't intend to be switching methods using update, this should be sufficient.) I'm not sure how important it is to handle this method-switching (since we can always use a regular if statement to switch generative methods if needed). If we want to handle method-switching, I think it will take some more thinking.

keorn commented 4 years ago

Great! Yes, in my view it is quite reasonable to expect different methods to have disjoint address spaces. While a neat program may have functions with the same name doing roughly the same things, in Julia the behaviour of a function can only be determined from its name together with its input types (and also the type hierarchy and the other methods).

belledon commented 4 years ago

Hi all!

Not sure if this is the right way to bump this topic, but this would be a great feature IMO.

I'm also happy to contribute if there are TODOs.

femtomc commented 4 years ago

@georgematheos @marcoct why is that constraint restriction in place?

The only reason I can think of is that you may be deleting addresses from the constraints as you visit them and then performing a runtime check that you've visited all constraints. Is this required specifically for the semantics of update, or is it only used for checking proposal-generated constraints during inference?

I completely understand why this is needed when proposing from a Gen function and then scoring the proposal. In that case the check makes sense, because it's the absolute continuity (AC) requirement. Maybe if you lift the restriction for the generic version of update, you could easily cover both branches and make this combinator more general?

femtomc commented 4 years ago

One way to do this would be to specialize the visitor structure that tracks which addresses you've visited, and construct a variant for update that also specifically tracks when you visit a choice that has been proposed. This would work alongside the normal runtime check that you don't visit the same address twice at the same level of the call stack.
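
One reading of this suggestion, as a Gen-free sketch: a visitor that rejects duplicate visits, a strict check that every constraint was visited (the kind of check that makes method switching error today), and a lenient variant for update that returns the unvisited constraints instead. `Visitor`, `visit!`, `check_all_visited`, and `unvisited_constraints` are hypothetical names, not Gen internals.

```julia
# Hypothetical sketch of an address visitor; names are illustrative, not Gen internals.
struct Visitor
    visited::Set{Symbol}
end
Visitor() = Visitor(Set{Symbol}())

# Normal runtime check: the same address may not be visited twice.
function visit!(v::Visitor, addr::Symbol)
    addr in v.visited && error("address $addr visited twice")
    push!(v.visited, addr)
    return v
end

# Strict check: every constrained address must have been visited.
function check_all_visited(v::Visitor, constraints::Dict{Symbol,Float64})
    for addr in keys(constraints)
        addr in v.visited || error("constraint at $addr was not visited")
    end
    return nothing
end

# Lenient variant for update: report unvisited constraints instead of erroring,
# so a method switch could drop them rather than fail.
unvisited_constraints(v::Visitor, constraints::Dict{Symbol,Float64}) =
    Dict(a => x for (a, x) in constraints if !(a in v.visited))

v = Visitor()
visit!(v, :y)                                     # the Float64 method only visits :y
constraints = Dict(:y => 0.3, :z => 0.31)
leftover = unvisited_constraints(v, constraints)  # holds only :z
```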

JackKenney commented 2 years ago

I'm not sure whether people are still thinking about this issue (and I don't mean to necrobump), but I'm working on a package that uses Gen heavily, and supporting this would make the design patterns a lot cleaner.

For context, I have one overarching model, and I would love to use multiple dispatch so that it supports several types and sub-structure specifications. For now, I have to define every potential model separately and work out how to dispatch between them manually.

ztangent commented 2 years ago

Unfortunately, we don't have anyone working on this feature right now, but as an alternative, you could use the following pattern that I ended up using for Plinf.jl:

Here's the (abstract) definition of the getter function I used for the abstract Planner type I wanted to dispatch on: https://github.com/ztangent/Plinf.jl/blob/2a8f86eb91095a37b9c2b034dc37529abfe945d7/src/planners/planners.jl#L26-L27

Here's the generative function that calls the getter function get_call internally: https://github.com/ztangent/Plinf.jl/blob/2a8f86eb91095a37b9c2b034dc37529abfe945d7/src/planners/planners.jl#L47-L53

And here's an example concrete definition of get_call for a concrete subtype of Planner: https://github.com/ztangent/Plinf.jl/blob/2a8f86eb91095a37b9c2b034dc37529abfe945d7/src/planners/astar.jl#L18-L22

Hope this helps address your use case! You could even automate away manually defining your equivalent of get_call with a macro, if you wanted.
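
A Gen-free sketch of the getter pattern described above. The planner types, `astar_plan`, `random_plan`, and `plan` are hypothetical; only the name `get_call` comes from the linked code, and the Gen-side `@trace` form in the comment is an assumption, so the Plinf.jl links remain the reference for the real definitions.

```julia
# Hypothetical Gen-free sketch of the getter pattern; the type and function
# names are illustrative, not Plinf.jl's actual definitions.
abstract type Planner end

struct AStarPlanner <: Planner
    max_nodes::Int
end

struct RandomPlanner <: Planner
    n_steps::Int
end

# Abstract getter: every concrete Planner subtype says which implementation to use.
get_call(p::Planner) = error("get_call not defined for $(typeof(p))")
get_call(::AStarPlanner) = astar_plan
get_call(::RandomPlanner) = random_plan

# Stand-ins for per-type generative functions (strings instead of traces).
astar_plan(p::AStarPlanner, goal) = "A* to $goal, at most $(p.max_nodes) nodes"
random_plan(p::RandomPlanner, goal) = "random walk to $goal, $(p.n_steps) steps"

# Single entry point. In Gen this would be one @gen function whose body traces
# the callee returned by get_call, presumably something like
# `@trace(get_call(planner)(planner, goal), :plan)` (assumed form).
plan(planner::Planner, goal) = get_call(planner)(planner, goal)

println(plan(AStarPlanner(100), :g))  # A* to g, at most 100 nodes
println(plan(RandomPlanner(3), :g))   # random walk to g, 3 steps
```

Adding a planner type then only needs a new subtype and one `get_call` method, which recovers much of the extensibility of multiple dispatch without a dispatch combinator.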

JackKenney commented 2 years ago

Thanks for your quick response! I'll check this pattern out; it looks promising.