dfdx / Umlaut.jl

The Code Tracer
MIT License

Tracing on optimized IR? #40

Open ToucheSir opened 1 year ago

ToucheSir commented 1 year ago

This is a speculative issue instead of a speculative PR because I'm not sure what the implementation ought to look like, but wanted to write it down for posterity:

My understanding is that, presently, Umlaut works with typed but unoptimized IR because inlining and related optimizations might dismantle primitives before they can be detected by isprimitive. So how do some AD libraries (Enzyme and Diffractor) make custom rules/primitives work on optimized IR given only static information? It turns out that they use the AbstractInterpreter API, specifically overriding Core.Compiler.inlining_policy. This is apparently sufficient to prevent inlining of selected function calls.
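
For intuition only, the sketch below uses the blunter, per-definition @noinline annotation rather than the inlining_policy override described above, which needs a custom AbstractInterpreter and whose details vary across Julia versions. It shows that a call the compiler is told not to inline survives into optimized typed IR as a single call, which is what a tracer would need in order to recognize it as a primitive. my_primitive and caller are hypothetical names.

```julia
# Crude stand-in for inlining control: mark the callee itself @noinline.
# (Enzyme/Diffractor reportedly override Core.Compiler.inlining_policy in a
# custom AbstractInterpreter instead, which leaves the callee's source untouched.)
@noinline my_primitive(x) = x^2 + 1   # hypothetical "primitive"
caller(x) = my_primitive(x) * 2

# Fetch optimized typed IR for caller(::Int); code_typed returns CodeInfo => return type pairs.
ci, rettype = code_typed(caller, Tuple{Int}; optimize=true)[1]

# The call to my_primitive is still present as its own statement (an :invoke),
# so a tracer reading this IR could treat it as one opaque unit.
survives = any(stmt -> occursin("my_primitive", string(stmt)), ci.code)
```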

So the question becomes whether Umlaut can also use this inlining control mechanism to fetch optimized typed IR. Functions like code_typed already support passing in a custom AbstractInterpreter. The main challenges I see are that isprimitive assumes you're passing values rather than just types, that the compiler APIs involved are fairly complex, and that compiler internals shift (i.e. break backwards compatibility) quite a bit across Julia versions.

dfdx commented 1 year ago

I see two approaches.

A simple approach would be to override getcode() to use a custom interpreter to extract code. It sounds doable. However, if the goal of tracing the optimized code is the speed, then changing only getcode() would not help much.

A more sophisticated approach is to rewrite the tracer entirely to make it fully static. This has many advantages, but it contradicts the main idea of Umlaut: capturing the exact execution trace of a program. Consider this function:

function foo(x)
    if x > 0
        return bar(x)
    else
        return baz(x)
    end
end

Currently, Umlaut uses the exact value of x to decide whether to continue tracing into bar() or baz(). Static type information is not sufficient to determine which branch to take. We could go the JAX way and introduce special conditional operators (like jax.lax.cond) or require such values to be marked static (like static_argnums). But then we would no longer be able to trace arbitrary Julia code, so it's quite a radical change.
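
For comparison, a JAX-style functional conditional could be sketched in Julia as below. static_cond, bar, baz and foo_static are hypothetical names (bar and baz get placeholder bodies); the point is that both branches become arguments, so a static tracer could record a single static_cond node without knowing the value of x > 0.

```julia
# Hypothetical functional conditional in the spirit of jax.lax.cond.
# Both branches are passed as functions, so a fully static tracer could record
# one static_cond node (holding both branches) instead of needing the runtime
# value of pred to pick a branch.
static_cond(pred::Bool, then_branch, else_branch, args...) =
    pred ? then_branch(args...) : else_branch(args...)

bar(x) = 2x   # placeholder branch bodies
baz(x) = -x

# foo rewritten so the branch is explicit data rather than ordinary control flow.
foo_static(x) = static_cond(x > 0, bar, baz, x)
```

The catch is the one described above: every ordinary if/else in traced code would need to be rewritten in this style.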

With both approaches, a huge problem is the lack of documentation for AbstractInterpreter. I have tried to use it several times, but with only a few examples available in other packages, using AbstractInterpreter requires a much deeper understanding of the Julia compiler than I have or can afford to acquire. Yet, if someone wants to try it, I'll be more than happy to describe a few possible designs in detail.

ToucheSir commented 1 year ago

I agree WRT not going the fully static route. Pragmatically, I don't think it would even work right now, because custom AbstractInterpreters are lost during dynamic dispatch. The tricky part of the getcode approach is that it would also need an equivalent of isprimitive that works on fully static signatures. Duplicating primitive detection logic in potentially two places seems clunky, but I can't think of any other way to do it.

> With both approaches, a huge problem is the lack of documentation for AbstractInterpreter.

Yes, I would've tried if not for this limitation. Anecdotally, it appears a lot of people who would be interested in interfacing with the compiler suffer from the lack of stable, documented interfaces in this area.

yebai commented 1 year ago

> However, if the goal of tracing the optimized code is the speed, then changing only getcode() would not help much.

I got curious about this. Intuitively, tracing optimised typed IR will result in an optimised tape. Can you clarify what other changes are required to get better speed?

In addition, a mixed approach blending the fully static and value-based tracing was explored in the Mjolnir.jl package. Would similar ideas help find a tradeoff between flexibility and performance?

ToucheSir commented 1 year ago

I don't quite understand either. Another benefit would be avoiding some of the gnarlier internal functions such as _apply(_iterate) which are usually compiled away in normal code.
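
To see the _apply_iterate point concretely, the sketch below compares unoptimized and optimized typed IR for a splatting call, using Base's code_typed with its optimize keyword. sum_splat and mentions_apply_iterate are hypothetical helpers, and matching on the printed statement text is a quick heuristic rather than a robust IR query.

```julia
# Splatting lowers to a call of Core._apply_iterate.
sum_splat(xs) = +(xs...)

splat_types = Tuple{Tuple{Int,Int,Int}}
unopt = code_typed(sum_splat, splat_types; optimize=false)[1].first  # CodeInfo
opt   = code_typed(sum_splat, splat_types; optimize=true)[1].first

# Heuristic: does any IR statement mention Core._apply_iterate when printed?
mentions_apply_iterate(ci) = any(stmt -> occursin("_apply_iterate", string(stmt)), ci.code)

mentions_apply_iterate(unopt)  # the splat is still an explicit _apply_iterate call
mentions_apply_iterate(opt)    # typically gone: for a known-length Tuple the optimizer rewrites it
```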

> In addition, a mixed approach blending the fully static and value-based tracing was explored in the Mjolnir.jl package. Would similar ideas help find a tradeoff between flexibility and performance?

Possibly, though one practical concern with Mjolnir.jl and more recent approaches like https://github.com/JuliaCompilerPlugins/Purple.jl is that they rely heavily on compiler internals or invent many of their own. In both cases, I think the maintenance burden ended up being part of what killed the project. Not having access to runtime values also creates some edge cases around dynamic dispatch; see e.g. https://github.com/FluxML/Mjolnir.jl/issues/4.

dfdx commented 1 year ago

> I got curious about this. Intuitively, tracing optimised typed IR will result in an optimised tape. Can you clarify what other changes are required to get better speed?

There are two types of speed involved:

  1. Speed of tracing, i.e. how fast trace(f, args...) itself runs. This is what I meant in my first comment. The last time I profiled the tracer, most of the time was spent in low-level compiler machinery like type inference. getcode() accounts for a fraction of the time, but it's certainly not the biggest performance issue during tracing.
  2. Speed of running the tape, or rather of running the code generated from the tape. By default, i.e. when you run play!(tape) or in Yota.jl, the tape is compiled back to a Julia function. In theory, the Julia compiler should apply all the same optimizations and reach approximately the same speed. Also, if you compile the tape into something more specific (e.g. ONNX), you can get even better, more specialized optimizations. In any case, I would not expect the optimizations in the original typed IR to have a large effect on optimizations in the final IR.
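
The idea of compiling a tape back into a Julia function can be illustrated with a deliberately tiny, hypothetical tape format. This is not Umlaut's Tape type or its compilation code; ToyOp, toy_tape and compile_toy_tape are made-up names, and each entry simply records an output variable, a function name and the input variables.

```julia
# Hypothetical toy tape entry: (output variable, function name, input variables).
const ToyOp = Tuple{Symbol, Symbol, Vector{Symbol}}

toy_tape = ToyOp[
    (:y1, :sin, [:x]),
    (:y2, :*,   [:y1, :x]),
]

# Build `x -> begin y1 = sin(x); y2 = y1 * x; y2 end` and eval it into a real function.
function compile_toy_tape(tape, input::Symbol, output::Symbol)
    body = [:($out = $f($(ins...))) for (out, f, ins) in tape]
    ex = Expr(:->, input, Expr(:block, body..., output))
    return eval(ex)  # evaluated in the enclosing module
end

toy_f = compile_toy_tape(toy_tape, :x, :y2)
toy_f(2.0)  # same computation as sin(2.0) * 2.0
```

Because the generated function is an ordinary Julia function, the compiler optimizes it as usual. Calling it from inside a function that was compiled before the eval may require Base.invokelatest because of world-age rules.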

Note that I mostly think about use cases like automatic differentiation and other code-to-graph-to-code transformations. It's pretty likely there are other use cases where your assumptions about optimized IR will be true.

ToucheSir commented 1 year ago

Thanks for the explanation. I presume we're all mostly thinking about use cases like AD, which is how this came up in the first place. One of the big takeaways from the first Enzyme paper is that running AD on already optimized code (i.e. optimize -> AD transform -> optimize) can result in significantly higher performance. We can see this in practice in the Julia ecosystem with Zygote, which has to work on IR sans any optimizations. Now something like Yota is not as susceptible to this because of its design, but it is still affected, because higher-level rrules can act as optimization fences and ChainRules (necessarily) contains many such rules.

willtebbutt commented 1 year ago

> Speed of tracing, i.e. how fast trace(f, args...) itself works. This is what I meant in the first comment. The last time I profiled the tracer, most time was spent on low-level compiler stuff like type inference. getcode() takes a fraction of time, but it's certainly not the biggest performance issue during the tracing.

I can well imagine that this is the case when you're using BaseCtx and don't end up with too many primitives on the tape. My use case is slightly different, though: I tend to end up with thousands of primitives on the tape. I did a (very rough) back-of-the-envelope calculation for one example with about 6000 items on the tape, and I believe that about two thirds of them would generally be removed by operating on optimised IR (for example, by letting the compiler optimise away things like apply_type, getglobal / getproperty(::Module, name), and inserting items into Tuples only to remove them immediately). I think this example is quite typical of my use case, so I would very much like to ensure that we can trace optimised IR.
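
A rough, hypothetical way to reproduce this kind of estimate on a small scale is to count IR statements with and without optimization. The sketch uses only Base's code_typed and its optimize keyword; tuple_roundtrip and nstmts are illustrative names, and exact counts depend on the Julia version, so only the direction of the difference should be read into it.

```julia
# Builds a tuple and immediately takes it apart again, mimicking the
# "insert into a Tuple, then remove" pattern described above.
function tuple_roundtrip(x)
    t = (x, x + 1)
    return t[1] * t[2]
end

# Number of statements in the typed IR for f(::types...), optimized or not.
nstmts(f, types; optimize) = length(code_typed(f, types; optimize=optimize)[1].first.code)

n_unopt = nstmts(tuple_roundtrip, Tuple{Int}; optimize=false)
n_opt   = nstmts(tuple_roundtrip, Tuple{Int}; optimize=true)
```

Each unoptimized statement would roughly correspond to a tape entry, while the optimized IR keeps only the integer arithmetic.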

@dfdx would you be open to a PR which requires isprimitive to be a function only of the call signature, rather than of the argument values? Looking through the code, isprimitive for BaseCtx already seems to use only type information anyway 🤷

edit: the docstring for isprimitive actually suggests that the signature, rather than the argument values, is what matters:

help?> Umlaut.isprimitive
  isprimitive(ctx::BaseCtx, f, args...)

  The default implementation of isprimitive used in trace(). Returns true if the
  method with the provided signature is defined in one of the Julia's built-in
  modules, e.g. Base, Core, Broadcast, etc.

  ──────────────────────────────────────────────────────────────────────────────────

  isprimitive(ctx::Any, f, args...)

  Fallback implementation of isprimitive(), behaves the same way as
  isprimitive(BaseCtx(), f, args...).
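
A signature-only version of this rule could look roughly like the sketch below. It is a hypothetical illustration, not Umlaut's implementation: BUILTIN_ROOTS, root_module, isprimitive_sig and isprimitive_vals are made-up names, and only methods defined under Core or Base are treated as primitive. It also shows one limitation of working from types alone: for callables that are not singletons (e.g. closures with captured variables), the function value cannot be recovered from its type, so the sketch conservatively returns false.

```julia
# Hypothetical signature-only primitive check (illustrative names, not Umlaut's API).
const BUILTIN_ROOTS = (Core, Base)

# Climb parent modules until reaching Core, Base, or a module that is its own parent.
function root_module(m::Module)
    while !(m in BUILTIN_ROOTS) && parentmodule(m) !== m
        m = parentmodule(m)
    end
    return m
end

# Decide using only the signature Tuple{typeof(f), argtypes...}.
function isprimitive_sig(sig::Type{<:Tuple})
    ts = fieldtypes(sig)
    ftype, argtypes = ts[1], Tuple{ts[2:end]...}
    # Without a value, f can only be recovered for singleton callable types.
    isdefined(ftype, :instance) || return false
    f = ftype.instance
    hasmethod(f, argtypes) || return false
    # E.g. sin(::Float64) lives in Base.Math, whose root is Base.
    return root_module(which(f, argtypes).module) in BUILTIN_ROOTS
end

# The value-based form reduces to a thin wrapper around the signature-based one.
isprimitive_vals(f, args...) =
    isprimitive_sig(Tuple{Core.Typeof(f), map(Core.Typeof, args)...})
```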

dfdx commented 1 year ago

Yes, in general I'm open to such a PR. But since it's going to be a breaking change, we will need some time and effort to evaluate possible consequences and implement proper semantic versioning.