probcomp / Gen.jl

A general-purpose probabilistic programming system with programmable inference
https://gen.dev
Apache License 2.0

Static codegen for dynamic DSL models #275

Open · femtomc opened this issue 4 years ago

femtomc commented 4 years ago

I have a few ideas about how to implement the static codegen dynamically which I wanted to chat about. The eventual goal of this work would be to obviate the need for separate static/dynamic modeling languages, and instead have Gen (or other systems which operate on trace-based principles) automatically detect what portions of the code are amenable to analysis. In this issue, I'll be posting references to my own WIP system, but the techniques should be transferable to Gen.


As far as I understand, the static language is restricted so that the DAG can be constructed explicitly and the dependencies (and Markov blanket) can be computed exactly. This allows static codegen to specialize the method body for inference, so that things are not recomputed unless they need to be. The approach I have in mind is more dynamic: it would encompass the static approach when the lowered method body admits a flow analysis (e.g. https://github.com/femtomc/Jaynes.jl/blob/e27689666e2aeee440671526eeb9bbd7c0d63081/src/core/static.jl#L69), but would always fall back on interpretation of the original function.

The basic idea would be to prototype an interface which dynamically generates a compiler pass that cuts out all the code outside of the Markov blanket of a site (maybe this is too permissive; I need to read the src a bit more), as well as all code outside of the critical path of an arg with argdiffs. This is exactly what the static IR does for codegen, but this pass would be applied to any method call (even dynamic ones), because the flow analysis should be extensible across basic blocks. This seems advantageous to me, because it doesn't require constructing new IR nodes.
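
For a straight-line body, the "critical path" part of such a pass boils down to a forward dependency slice. Below is a minimal sketch of that slice, assuming a hypothetical `Stmt` representation (one defined name plus the names it reads) rather than Gen's IR or lowered Julia code. It treats every change conservatively and ignores the possibility that a re-run statement produces an unchanged value.

```julia
# Hypothetical straight-line program representation: each statement names
# the value it defines and the values it reads. Not Gen's IR.
struct Stmt
    out::Symbol
    ins::Vector{Symbol}
end

# Return, in program order, the statements an update must re-run: those whose
# address is constrained, plus everything downstream of a changed argument or
# of a re-run statement. Every other statement could be cut from the body.
function statements_to_rerun(prog::Vector{Stmt}, changed_args, constrained::Set{Symbol})
    dirty = Set{Symbol}(changed_args)
    rerun = Symbol[]
    for s in prog                      # program order is a topological order
        if s.out in constrained || any(i -> i in dirty, s.ins)
            push!(dirty, s.out)
            push!(rerun, s.out)
        end
    end
    return rerun
end

# Example body: a reads x, b reads a, c reads y, d reads b and c.
prog = [Stmt(:a, [:x]), Stmt(:b, [:a]), Stmt(:c, [:y]), Stmt(:d, [:b, :c])]
```

With `y` changed and `:b` constrained, only `b`, `c` and `d` need to run, and `a` can be skipped.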

For starters, I would try this for the equivalent of the static IR (no control flow) and then widen the pass. It's easy to check whether a function satisfies the static spec (see e.g. https://github.com/femtomc/Jaynes.jl/blob/e27689666e2aeee440671526eeb9bbd7c0d63081/src/core/static.jl#L39), so a first implementation could leave functions which fail the check alone. To widen the pass, you could start by using a partial-evaluation tracing approach (e.g. https://github.com/MikeInnes/Mjolnir.jl) to attempt to unroll loops and remove control flow that can be resolved, so that programs with removable control flow still end up satisfying the static spec.
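
As a rough illustration of the "no control flow" check, here is a sketch that inspects a method's lowered code with Base's `code_lowered` and rejects anything containing a branch. It assumes Julia 1.6 or later, where lowered branches appear as `Core.GotoNode` / `Core.GotoIfNot` statements. The real static spec constrains more than control flow, so treat this as a first filter only. The example functions are made up.

```julia
# Crude "static spec" filter: accept a method only if its lowered body has
# no branches (no if/else, no loops, no short-circuiting).
function is_straight_line(f, argtypes)
    cis = code_lowered(f, argtypes)
    length(cis) == 1 || return false          # no unique matching method
    branch(st) = st isa Core.GotoNode || st isa Core.GotoIfNot
    return !any(branch, cis[1].code)
end

# Example functions for the check.
add_then_double(x) = (y = x + 1; y * 2)       # straight-line
clamp_pos(x) = x > 0 ? x : zero(x)            # conditional branch
loop_sum(n) = (s = 0; for i in 1:n; s += i; end; s)  # loop
```

A pass could leave any function for which `is_straight_line` returns `false` to the ordinary dynamic interpreter.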

One easy point of access for this sort of tool is the call barriers (e.g. https://github.com/femtomc/Jaynes.jl/blob/e27689666e2aeee440671526eeb9bbd7c0d63081/src/core/contexts.jl#L236 for a Cassette-based point of access). Gen also has these as you traverse across the call interface, and I'm imagining that the pass could be applied at this point of access in Gen.

One thing I would need to investigate for Gen integration is whether it's feasible, performance-wise, to look up the lowered code for the function inside the GenerativeFunction. This would only occur in the update and regenerate execution contexts. However, if the specialized result is cached, the number of lookups would asymptotically decline as you perform inference. In static codegen, does the generated code for a particular set of argdiffs and constraint choicemaps get cached (i.e. so you don't regenerate the method body when you get the same information)? I would need to figure out how to keep the transformed code cached so that the pass isn't rerun once it has been performed.
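
One way to picture the caching question: key each specialized body on the argdiff *types* plus the set of constrained addresses, and only run the expensive pass on a cache miss. The sketch below is hypothetical throughout (`specialize`, `cached_update`, and the stand-in `NoChange`/`UnknownChange` types are not Gen's), and the "pass" just returns a closure instead of transforming lowered code.

```julia
# Stand-ins for Gen's diff types (assumption: not Gen's definitions).
struct NoChange end
struct UnknownChange end

const SPECIALIZED = Dict{Any,Function}()     # key => specialized body
const N_SPECIALIZATIONS = Ref(0)             # how often the "pass" ran

# Pretend compiler pass: a real one would transform lowered code. Here it
# only decides which addresses the specialized body would revisit.
function specialize(argdiff_types, constrained)
    N_SPECIALIZATIONS[] += 1
    revisit = all(==(NoChange), argdiff_types) ? collect(constrained) : [:all]
    return () -> revisit
end

# Look up (or build once) the body for this argdiff/constraint shape.
function cached_update(argdiffs::Tuple, constrained::Tuple)
    key = (map(typeof, argdiffs), constrained)
    f = get!(() -> specialize(key...), SPECIALIZED, key)
    return f()
end
```

Repeated calls with the same shape hit the cache, so the number of pass invocations stays bounded by the number of distinct shapes seen during inference.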


There are a few advantages to this implementation:

  1. The fallback is always dynamic interpretation, so the pass (if allowed) could be applied to arbitrary GenerativeFunction code, where it would determine whether the method body is amenable to transformation.
  2. This should interact well with the combinators. As far as my understanding goes, these calls already provide dependency information to the inference engine - so the pass can ignore them (or potentially utilize them, I haven't thought deeply about this beyond the Map combinator).

The big thing, I think, is that this makes the system more user-friendly, especially if we can push annotations off to compiler analysis.

While this is a large (and long-term) issue, I just wanted to open up a space for discussion. I've been thinking about the static language from a naive perspective, and it would certainly be helpful to have more expertise critique these ideas.

alex-lew commented 4 years ago

Thanks for posting this, @femtomc! I hadn't seen Mjolnir -- that looks super cool. Will read this a bit more carefully before saying more.

georgematheos commented 4 years ago

This is a really interesting idea! Thank you @femtomc for posting this! I think that it's possible this sort of approach for static compilation could pan out. What I'm understanding is that one way to look at this is:

Currently, static updates are compiled via a big Julia file which explicitly contains code to calculate Markov blankets and uses macros to write Julia source code encoding the minimum amount of work needed for a valid update, given the Markov blanket structure of a given model.

But it might be possible to do this instead: write all the code in the dynamic DSL, and don't have any code in the Gen repo that writes update functions using macros. Instead, write the code in such a way that Julia's type inference will automatically compile our dynamic updating code into specialized update functions which understand Markov blankets.

Let me illustrate how I'm envisioning this could work:

Say we have the following generative function:

@gen function foo(x)
    a ~ bar1(x)
    b ~ bar2(a)
    return b
end

If we run update(foo_trace, (x,), (NoChange(),), StaticChoiceMap(:b => :c => 5)), the dynamic DSL update code would specify to do the following:

  1. Update :a by calling update(bar1_tr, (x,), (NoChange(),), EmptyChoiceMap()). This will result in a dispatch to the method update(tr::Trace, ::Tuple, ::NTuple{n, NoChange} where {n}, ::EmptyChoiceMap) = (tr, 0., NoChange(), EmptyChoiceMap()).
  2. Update :b by running update(bar2_tr, (a,), (NoChange(),), StaticChoiceMap(:c => 5)). (We know the argdiff is NoChange(), since NoChange() was returned by the call to update(bar1_tr).)

Note that the update call's type information is informative enough for the compiler to realize that the update of a ~ bar1(x) will result in a NoChange. So if the Julia compiler is aggressive enough, it should be able to realize at compile time that it never needs to update the bar1 call, and it can encode in the compiled program that it will start by updating bar2 with a NoChange argdiff. This means the Julia compiler, by utilizing the argdiff type information implicit in our declarations of the update function, might be able to effectively calculate Markov blankets at compile time without us specifically writing code to do this! All the user has to do is make sure that as much decision-making about returning NoChange as possible is done via method dispatch, rather than if statements. (And using something like Mjolnir or regular Julia constant propagation may even mean that if the return type is determined by Julia control flow, we can still automatically figure out Markov blankets at compile time.)
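
The dispatch idea can be tried in miniature without Gen. The sketch below defines hypothetical stand-ins (`LeafTrace`, `EmptyChoiceMap`, a two-method `update`, and `update_foo` for the `foo` model above). None of these are Gen's actual definitions. Because the shortcut method returns a concrete `NoChange()`, the type of `adiff` in `update_foo` is known when `xdiff::NoChange` and the constraints are empty, so Julia can pick the second `update` method statically. `@code_typed` is one way to inspect whether it actually does.

```julia
# Hypothetical miniature of the GFI (assumption: not Gen's definitions).
struct NoChange end
struct UnknownChange end
struct EmptyChoiceMap end

struct LeafTrace
    retval::Float64
end

const CALLS = Ref(0)   # counts updates that do "real" work

# Fallback: work is needed and the return value may change.
function update(tr::LeafTrace, args::Tuple, argdiffs::Tuple, constraints)
    CALLS[] += 1
    return (LeafTrace(sum(args)), 0.0, UnknownChange(), EmptyChoiceMap())
end

# The dispatch-level shortcut: unchanged args and no constraints.
update(tr::LeafTrace, ::Tuple, ::NTuple{n,NoChange}, ::EmptyChoiceMap) where {n} =
    (tr, 0.0, NoChange(), EmptyChoiceMap())

# foo(x): a ~ bar1(x); b ~ bar2(a). The retdiff of the :a update is passed
# as the argdiff of the :b update, so its type drives method selection.
function update_foo(a_tr, b_tr, x, xdiff, a_constraints, b_constraints)
    a_new, wa, adiff, _ = update(a_tr, (x,), (xdiff,), a_constraints)
    b_new, wb, bdiff, _ = update(b_tr, (a_new.retval,), (adiff,), b_constraints)
    return (a_new, b_new, wa + wb, bdiff)
end
```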

I don't know how heavily optimized the Julia compiler is, so I don't know how well this would work in practice. (Perhaps tools like Mjolnir are needed to get compilation as aggressive as this would require.) But this does seem like it would give us a way to handle the tradeoff between dynamic and static compilation even more automatically.

femtomc commented 4 years ago

@georgematheos this is the right idea. In practice, I don’t think the compiler will be able to determine this automatically; it will instead require a few motivational shoves.

On Mjolnir, I think it’s possible to implement the transformation you described. If anything, I think Mike Innes would be happy to help/support usage of the tool here.

The fallback is to munge the IR directly. After a night of thought, I’m a little less sure about this for Gen, because of the other infrastructure required for GenerativeFunctions. In particular, the insertion of macro-time code around calls means the method bodies grow somewhat large? Ignoring that, the other question is: if you munge the IR and compile it to an anonymous function, can you stick that back inside the GenerativeFunction and have it just work?

Thus, I think your suggestion of trying out Mjolnir first is a good path. Mjolnir is very new, so it’s likely we’ll run into issues related to https://github.com/MikeInnes/Mjolnir.jl/issues/3, because coverage of intrinsics and built-ins is a work in progress!

alex-lew commented 4 years ago

@georgematheos I agree that's an intriguing approach. (One issue I'm not sure it addresses is how to recover the static DSL's specialized trace data structures.) It could be interesting to explore a version of the dynamic DSL designed to propagate diffed values in a way that was intelligible to (a fork of) Mjolnir, and see how far this approach could be pushed.

Another option is to explore the design of static analyses (aided by abstract interpretation / partial evaluation) of only the model code (and not the GFI implementation). These analyses could then inform the creation of trace data structures and the generation of GFI code. This implementation would not be any simpler than the existing static DSL compiler (it would likely be a lot more complex), but could potentially be more automatic for the user, and discover more interesting opportunities for incremental computation. @femtomc Maybe this is what you mean by "munge the IR"? But I'm not sure what you mean by "stick that back inside the GenerativeFunction." I don't think we could get away with a single compiled function -- you'd need to compile different code for each GFI method (and perhaps for each choice-map shape, as in the static DSL today).
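
To make "different code for each choice-map shape" concrete, one Julia mechanism is a `@generated` function whose input type encodes the shape. The sketch below assumes constraints are represented as a `NamedTuple`, so the constrained addresses are part of the type, and uses a hypothetical per-address hook `visit`. Julia compiles and caches one body per shape. This is an illustration of the idea, not how the static DSL is implemented.

```julia
# Hypothetical per-address update hook.
visit(addr::Symbol, value) = (addr, value)

# The generator runs at compile time and sees only the type, where `names`
# is the tuple of constrained addresses. The emitted body touches only those.
@generated function update_constrained(constraints::NamedTuple{names}) where {names}
    calls = [:(visit($(QuoteNode(n)), getfield(constraints, $(QuoteNode(n)))))
             for n in names]
    return Expr(:tuple, calls...)
end
```

Calling `update_constrained((b = 5,))` and `update_constrained((a = 1, b = 2))` compiles two separate method bodies, each with no loop or lookup over unconstrained addresses.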

femtomc commented 4 years ago

@alex-lew I should have been a bit more specific.

> The fallback is to munge the IR directly. After a night of thought, I’m a little less sure about this for Gen, because of the other infrastructure required for GenerativeFunctions. In particular, the insertion of macro-time code around calls means the method bodies grow somewhat large? Ignoring that, the other question is: if you munge the IR and compile it to an anonymous function, can you stick that back inside the GenerativeFunction and have it just work?

I wrote this from the point of view of my implementation of the GFI methods, which are external "context" structures and not explicit methods, so you're absolutely right in your latter comments. This is actually the approach I'm taking, but only because the infrastructure is amenable to it (see e.g. https://github.com/femtomc/Jaynes.jl/blob/master/scratch/generated_passes.jl). So I likely only have to handle compilation of a single method body, and not the entire GFI, though the performance characteristics are not known yet.

From your comments here, though, I may be misunderstanding the potential for optimizations: why would specialized methods be required for generate (for example)? Unless this is exactly the connection to specialized Trace structures, e.g. generate will produce an optimized trace structure (which, I'm guessing, is optimized w.r.t. memory usage and access?). That makes the most sense to me. I'll return after I read the other GFI elements in the static src.

Edit: ah I see the relevant bit:

function generate_trace_struct(ir::StaticIR, trace_struct_name::Symbol, options::StaticIRGenerativeFunctionOptions)
    mutable = false
    fields = get_trace_fields(ir, options)
    field_exprs = map((f) -> Expr(:(::), f.fieldname, f.typ), fields)
    Expr(:struct, mutable, Expr(:(<:), trace_struct_name, QuoteNode(StaticIRTrace)),
         Expr(:block, field_exprs..., Expr(:(::), static_ir_gen_fn_ref, QuoteNode(Any))))
end

I'm guessing this generates a struct with exactly the fields required for the random choices?
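
A stripped-down, self-contained version of the same metaprogramming pattern may help: build an `Expr(:struct, ...)` from `(fieldname => type)` pairs and evaluate it. `SketchTrace` and `define_trace_struct` are hypothetical names, and the field list is supplied by hand here, whereas `generate_trace_struct` derives it from the `StaticIR` via `get_trace_fields`.

```julia
# Hypothetical supertype standing in for StaticIRTrace.
abstract type SketchTrace end

# Build and evaluate an immutable struct with exactly the given typed fields.
function define_trace_struct(mod::Module, name::Symbol, fields)
    field_exprs = [Expr(:(::), f, T) for (f, T) in fields]
    ex = Expr(:struct, false,                      # false = immutable
              Expr(:(<:), name, SketchTrace),
              Expr(:block, field_exprs...))
    return Core.eval(mod, ex)
end

# One field per random choice, with a concrete type for each.
define_trace_struct(@__MODULE__, :FooTrace, [:a => Float64, :b => Float64])
```

Concrete field types let the compiler lay the trace out like a plain record, which is one plausible reason for the memory and access benefits guessed at above.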

I have thought a bit about generating specialized trace types from a static analysis, but I certainly need to think more, as the points you raise are key issues. @georgematheos, one place to start with a Mjolnir prototype would be https://github.com/MikeInnes/Mjolnir.jl/blob/master/docs/types.md, to determine whether the argdiffs can be encoded in Mjolnir's type extensions to the compiler. That seems like a smaller problem than replacing the entire static DSL, and would instead be a useful optimization for the dynamic DSL.

Vague thoughts: @georgematheos, the Mjolnir approach is essentially to specialize the interpreter to the provided program, so it's a sort of Futamura-style specialization, right? The static DSL performs the same function, but there you can specialize away everything at compile time.

This is one perspective which I've been trying to wrap my head around - the dynamic DSL only allows you to partially specialize, but as you re-run the interpreter you might be able to gradually learn more about the program.

alex-lew commented 4 years ago

@femtomc Yes, I took that to be your and George's proposed approach: write a dynamic DSL interpreter (or eight, for the various GFI methods) using patterns that lend themselves to compilation-by-partial-evaluation (Futamura-style).

I think that's an interesting approach to explore, to see how far it can be pushed. I'm excited to see what you discover :-) and couldn't resist playing around with Mjolnir myself a bit yesterday. But I suspect many aspects of Static+Combinators will be difficult to get at using that approach. For example, I don't see immediately how partial evaluation could determine that

function f(n, a)
  for i=1:n
    if i == a
      do_stuff!(i)
    end
  end
end

specializes to do_stuff!(a) no matter what n is. But note that this optimization is at the core of Map's update implementation: visiting only some iterations of a loop. There are static analysis approaches beyond straightforward partial evaluation that can automate this sort of thing (e.g. the "histogram transformation" in https://dl.acm.org/doi/abs/10.1145/3341702). In a project about "compiling the dynamic DSL," it could be useful to keep those options on the table.
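
For contrast, here is the kind of update a Map-style combinator can do once it is *told* which iterations changed: cost proportional to the number of changed indices rather than to `n`. This is a sketch of the idea only (the name `map_update!` and its arguments are hypothetical, not Gen's Map implementation). It assumes each iteration depends only on its own input, which is exactly the fact partial evaluation of the loop above would have to discover.

```julia
# Keep per-iteration results and recompute only the iterations whose
# inputs changed. Returns the number of iterations actually visited.
function map_update!(results::Vector{Float64}, kernel, inputs::Vector{Float64},
                     changed::AbstractSet{Int})
    visited = 0
    for i in changed                 # O(|changed|), independent of length(results)
        results[i] = kernel(inputs[i])
        visited += 1
    end
    return visited
end
```

After changing `inputs[7]` in a 1000-element vector, calling `map_update!` with `Set([7])` visits a single iteration and leaves `results` equal to a full recomputation.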

Also, having something like the static graph around is nice. This is in part because the static computation graph can be used in ways that don't fit as neatly into the Futamura picture: graph analysis can be used to improve the statistical/inferential properties of the underlying algorithm and not just the implementation's performance.

georgematheos commented 4 years ago

Related to a comment @femtomc made on Slack: to do this sort of "automatic" static compilation, the dynamic DSL will need to propagate argdiffs. While there are still design questions around how (or whether) to do static compilation of dynamic models, I suspect this argdiff propagation might be useful for performance even if we don't try to do compilation.

One possible implementation strategy could be to change the GFI so that we don't pass in argdiffs separately from the args. Instead, we pass Diffed values as arguments and return Diffed values.

An example:

# current code:
update(tr, (1, 2), (NoChange(), UnknownChange()), choicemap)

# new code:
update(tr, (Diffed(1, NoChange()), Diffed(2, UnknownChange())), choicemap)
# which is equivalent to:
update(tr, (Unchanged(1), 2), choicemap)
# where `Unchanged` constructs a NoChange diffed val, and vals without a diff are treated as `UnknownChange`

If we then also updated the traceat function for update to return Diffed(get_retval(new_tr), retdiff) after every update, argdiffs would automatically get propagated without changing the dynamic DSL code at all. I.e., if there were static calls within the dynamic function, they would know the proper argdiffs.
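
A small sketch of how the new-style arguments could be translated back into today's `(args, argdiffs)` pair, following the rule that values without a diff are treated as `UnknownChange`. `Diffed`, `Unchanged`, `split_diffed` and `to_current_gfi` are defined here for illustration, along with stand-in diff types, and are not assumed to match Gen's own definitions.

```julia
# Stand-in diff types and a hypothetical value-plus-diff wrapper.
struct NoChange end
struct UnknownChange end
struct Diffed{V,D}
    value::V
    diff::D
end

Unchanged(v) = Diffed(v, NoChange())

# Plain values carry no diff, so they are treated as UnknownChange.
split_diffed(x::Diffed) = (x.value, x.diff)
split_diffed(x) = (x, UnknownChange())

# New-style args tuple -> (args, argdiffs), as the current GFI expects.
function to_current_gfi(args::Tuple)
    pairs = map(split_diffed, args)
    return (map(first, pairs), map(last, pairs))
end
```

Both spellings from the example above map to the same call: `to_current_gfi((Unchanged(1), 2))` gives `((1, 2), (NoChange(), UnknownChange()))`.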

When the dynamic gen function includes calls to Julia functions which can dispatch on Diffed values, the diffs will automatically get updated as needed. To handle Julia functions which can't dispatch on them, we could implement automatic conversion from Diffed values to non-diffed values (effectively automatically stripping the diff and changing to NoChange).
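
Below is a sketch of both halves: a diff-aware method for `+`, and a hypothetical `lift` wrapper for diff-unaware functions that strips the diffs and calls the plain function. For the result's diff, this sketch makes its own conservative choice, which is not necessarily what the comment above intends: `NoChange` only when every input is `NoChange`, otherwise `UnknownChange` (plain values count as unknown). That choice assumes the wrapped function is pure.

```julia
# Stand-in diff types and value-plus-diff wrapper (not Gen's).
struct NoChange end
struct UnknownChange end
struct Diffed{V,D}
    value::V
    diff::D
end

# Diff-aware methods: a sum is unchanged only if both operands are.
Base.:+(a::Diffed{<:Any,NoChange}, b::Diffed{<:Any,NoChange}) =
    Diffed(a.value + b.value, NoChange())
Base.:+(a::Diffed, b::Diffed) = Diffed(a.value + b.value, UnknownChange())

# Hypothetical automatic conversion for diff-unaware (pure) functions.
undiff(x::Diffed) = x.value
undiff(x) = x
isunchanged(x::Diffed{<:Any,NoChange}) = true
isunchanged(x) = false

function lift(f, args...)
    d = all(isunchanged, args) ? NoChange() : UnknownChange()
    return Diffed(f(map(undiff, args)...), d)
end
```

With this, `lift(max, x, y)` works for any `max`-compatible values without `max` knowing about diffs.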

alex-lew commented 4 years ago

@georgematheos Can you think of a good way to do that automatic conversion? (This is a challenge with the current diffs annotation too -- you just get errors if you're not careful.)

Another challenge arises with Diffed values and control flow.

For example, consider

x = if a; b; else; c; end

where a has no diff information and b and c are annotated with NoChange. Then if we are not careful, x will be labeled with NoChange, even though its value may have changed (thanks to branching). I think loops, recursion, local mutation, and all the other things that we support in the dynamic DSL could make this even more complicated.
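
One possible sound rule for this example, as a sketch: if the condition's value might have changed, the branch may have flipped, so the result gets `UnknownChange` even when both branches are `NoChange`. `diff_ifelse` is a hypothetical name, and the stand-in types are defined locally. A condition with no diff information (a plain `Bool`) is treated as possibly changed.

```julia
# Stand-in diff types and value-plus-diff wrapper (not Gen's).
struct NoChange end
struct UnknownChange end
struct Diffed{V,D}
    value::V
    diff::D
end

# Diff rule for `x = if a; b; else; c; end`.
function diff_ifelse(a::Diffed, b::Diffed, c::Diffed)
    chosen = a.value ? b : c
    if a.diff isa NoChange
        return chosen                                  # same branch as before
    else
        return Diffed(chosen.value, UnknownChange())   # branch may have flipped
    end
end

# No diff information on the condition: assume it may have changed.
diff_ifelse(a::Bool, b::Diffed, c::Diffed) = diff_ifelse(Diffed(a, UnknownChange()), b, c)
```

In the example where `a` carries no diff, `diff_ifelse(a, b, c)` reports `UnknownChange` instead of the unsound `NoChange`.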

However, it may be possible to come up with some general rules at the level of IRTools IR, e.g., which would be neat :-) I think a bar for this change would be that such a transformation should not be brittle: the user shouldn't run into weird errors in the dynamic DSL just because they're using Julia features that we hadn't foreseen.

femtomc commented 4 years ago

This problem seems like it could be addressed by a fork of IRTracker.jl. It is essentially self-adjusting computation, so a dynamic dependency graph approach seems like it could be expressed in that infrastructure.

I'm a bit concerned about memory usage in the dynamic case. I haven't investigated this package yet - but it appears to track everything as Expr nodes in the dynamic graph (thus, superficially resembling a dynamic version of the static IR). It would be nice if you could define custom nodes for the tracker.

I now understand this side of the static IR better: the DAG is exactly the dependency graph which you would require for incremental re-computation. I'm having a difficult time seeing how you could express this as a static pass, or even an IRTools pass (just because IRTools doesn't work with typed IR, so if you expressed diff information at the type level, you couldn't utilize it)... unless you're talking about using a dynamo, @alex-lew (which is how IRTracker.jl is implemented).
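
For intuition about the self-adjusting-computation view, here is a toy dynamic dependency graph with change propagation and "early cutoff": when a recomputed node's value equals its old value, propagation stops there. Everything here (`Node`, `input`, `node`, `set!`) is hypothetical and unrelated to IRTracker's or Gen's APIs. It does not process nodes in topological order, so a node can be recomputed more than once, though final values are still correct for a DAG.

```julia
# A node holds a value, how to compute it, and its graph edges.
mutable struct Node
    value::Any
    compute::Union{Function,Nothing}   # `nothing` for input nodes
    inputs::Vector{Node}
    dependents::Vector{Node}
    recomputations::Int
end

input(v) = Node(v, nothing, Node[], Node[], 0)

function node(f, inputs::Node...)
    n = Node(f(map(i -> i.value, inputs)...), f, collect(inputs), Node[], 0)
    foreach(i -> push!(i.dependents, n), inputs)
    return n
end

# Change an input and re-run only downstream nodes, stopping wherever a
# recomputed value equals the previous one (early cutoff).
function set!(n::Node, v)
    n.value == v && return
    n.value = v
    queue = copy(n.dependents)
    while !isempty(queue)
        d = popfirst!(queue)
        new = d.compute(map(i -> i.value, d.inputs)...)
        d.recomputations += 1
        if new != d.value
            d.value = new
            append!(queue, d.dependents)
        end
    end
end
```

For example, with `a = node(v -> v % 2, x)` and `b = node(+, a, y)`, changing `x` from 1 to 3 recomputes `a` but not `b`, because the parity did not change.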

phipsgabler commented 4 years ago

> I'm a bit concerned about memory usage in the dynamic case. I haven't investigated this package yet - but it appears to track everything as Expr nodes in the dynamic graph (thus, superficially resembling a dynamic version of the static IR). It would be nice if you could define custom nodes for the tracker.

That's about right. You might have seen that IRTracker originally had been called "DynamicComputationGraphs" :) The IR is tracked dynamically in what really is a Wengert list, but including branches and block arguments. There has been some talk about re-tracing graphs (cf. here), without any progress though.
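
For readers unfamiliar with the term, a Wengert list is a flat tape of primitive operations, each referring to earlier entries by index. The toy below records one via operator overloading on a hypothetical `Tracked` type. Unlike IRTracker, which works on IR and also records branches and block arguments, this sketch only sees the operations on the path actually executed.

```julia
# A tracked scalar remembers its position on the tape.
struct Tracked
    value::Float64
    index::Int
end

# The tape: (operation, indices of argument entries, resulting value).
const TAPE = Tuple{Symbol,Vector{Int},Float64}[]

function record(op::Symbol, value, args::Tracked...)
    push!(TAPE, (op, Int[a.index for a in args], value))
    return Tracked(value, length(TAPE))
end

track(x) = record(:input, Float64(x))
Base.:+(a::Tracked, b::Tracked) = record(:+, a.value + b.value, a, b)
Base.:*(a::Tracked, b::Tracked) = record(:*, a.value * b.value, a, b)
```

Evaluating `x * y + x` on tracked inputs appends one `:*` entry and one `:+` entry whose argument indices point back at the inputs.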

And "defining custom nodes" would indeed be nice! Currently, what you can do is add arbitrary metadata to the NodeInfo. But to be really flexible, one should probably do what I suggested here -- strip down IRTracker to a Cassette-like system.

For now, the code is unfortunately only research quality in some respects. You might have noticed the huge compile times on first tracking, but even after that there's a huge performance difference. The expression representation consists of nested parametrized types, which is quite heavy on the compiler, and the way my version of "overdubbing" works still contains a type instability (essentially due to the fact that functions and their arguments are wrapped into nodes and then unpacked again -- I don't really know what's going on with the types there, sorry...).

alex-lew commented 4 years ago

@femtomc If you're interested in how dynamic computation graphs can be used for efficient trace updates, you might want to look at Venture (https://arxiv.org/pdf/1404.0099.pdf). Gen is much faster than Venture, partly thanks to Julia, and partly because of the overhead of Venture's dynamic dependency tracking. But it may help for thinking through what sort of information needs to be tracked to enable asymptotically efficient updates in fully dynamic models.

femtomc commented 4 years ago

@alex-lew thank you, I'll take a read. It seems like the propagation of argdiff information doesn't necessarily need to follow this sort of infrastructure, if there's a clever way to solve the conversion problem for methods which are diff-unaware.

femtomc commented 4 years ago

@alex-lew @georgematheos Quick update on some of this thought (unfortunately not Gen-specific yet):

I just re-factored my prototype to:

  1. use IRTools instead of Cassette
  2. allow the application of a custom pass in regenerate and update contexts here. This should provide some fertile ground for exploring dynamic IR transformations. This seems complementary to propagating argdiffs - but argdiffs are still required to do the sort of efficient updates discussed above.

One thing I need to determine is whether I can customize the pass using metadata in the context; this is basically the linchpin for this sort of implementation to work.

A simple pass might be:

  1. Look at dependency graph collected in context.
  2. Determine update addresses.
  3. Cut IR above address.

This is not correct yet, because you do need argdiff information to determine whether you can actually cut or not. But if I can get the pipeline working up to this point, it's at least feasible to do some of the things we've been discussing.
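
The three steps plus the missing argdiff check can be sketched for a straight-line body. The assumption is that the body is just an ordered list of addresses; `cut_index` and the stand-in diff types are hypothetical. If every argdiff is `NoChange`, nothing before the first updated address can have changed, so that prefix can be cut. Otherwise nothing is cut.

```julia
# Stand-in diff types (not Gen's).
struct NoChange end
struct UnknownChange end

# Steps 2-3 with the argdiff guard: return the index of the first statement
# that must run. Statements before it can be cut; a result of
# length(body) + 1 means the whole body can be skipped.
function cut_index(body::Vector{Symbol}, updated::Set{Symbol}, argdiffs::Tuple)
    all(d -> d isa NoChange, argdiffs) || return 1   # args may have changed
    i = findfirst(addr -> addr in updated, body)
    return i === nothing ? length(body) + 1 : i
end
```

This only removes the prefix; statements after the cut that don't depend on the updated address would still run unless a finer dependency analysis is added.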

The way this is set up, you only perform this sort of analysis at a call site before you recurse in, so you could apply something like Mjolnir before descending into the call. I'll begin experimenting with some of this runtime IR transformation stuff after the holiday. Enjoy!

georgematheos commented 4 years ago

Pasting in some thoughts from @femtomc on the Julia Slack which I think are relevant to this discussion:

McCoy R. Becker 9:22 AM: I wonder if the diffs issue in the dynamic DSL can be solved by a combo of:

  1. A diff system which is set up like Zygote, where you have a lib of base functions defined with a "diff forward" pushforward function.
  2. A static IR (or macro!) pass which extracts if control flow into concrete ifelse calls.

I'm trying to determine if this checks off the boxes which Alex raised in our earlier conversation. It seems like you should be able to handle switches over control flow if you define the diff pushforward for ifelse, there's good base coverage, and you write a small transformation in the lang macro.
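
The macro half of that combo might look roughly like the following sketch. It rewrites `if c; t; else; e; end` (including `elseif` chains and ternaries, which parse as `:if`) into calls to a function, here a hypothetical `lifted_ifelse`, on which diff pushforwards could then be defined. Two caveats are assumptions of this sketch: every `if`/`elseif` chain must end in an `else`, and because `lifted_ifelse` is an ordinary function, both branches are evaluated eagerly, so the rewrite is only semantics-preserving for side-effect-free branches. Julia's own `ifelse` has the same eager semantics.

```julia
# Hypothetical target of the rewrite; counts calls so the rewrite is visible.
const REWRITTEN = Ref(0)
lifted_ifelse(c, t, e) = (REWRITTEN[] += 1; c ? t : e)

# Recursively turn 3-argument :if / :elseif nodes into lifted_ifelse calls.
function ifs_to_calls(ex)
    ex isa Expr || return ex
    args = map(ifs_to_calls, ex.args)
    if ex.head in (:if, :elseif) && length(args) == 3
        return :(lifted_ifelse($(args[1]), $(args[2]), $(args[3])))
    end
    return Expr(ex.head, args...)
end

macro ifelse_pass(ex)
    return esc(ifs_to_calls(ex))
end

# Example: the branch becomes lifted_ifelse(x > 0, x, -x).
function myabs(x)
    @ifelse_pass if x > 0
        x
    else
        -x
    end
end
```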