dfdx / Yota.jl

Reverse-mode automatic differentiation in Julia
MIT License

Performance analysis #131

Closed · dfdx closed 1 year ago

dfdx commented 1 year ago

Here I'm going to track the performance issues with Yota

Starter code:

```julia
using Yota
using Yota.Umlaut
using Metalhead
using Profile
using ProfileView

loss(model, image) = sum(model(image))

function main()
    model = Metalhead.ResNet(18)
    image = rand(Float32, 224, 224, 3, 1)
    @time model(image)
    @time trace(model, image; ctx=GradCtx())
    @profile grad(loss, model, image)
    Profile.print(mincount=100)
end
```

Currently, computing the gradient of ResNet takes forever (or at least > 30 minutes). Executing the function takes ~10 seconds and tracing it ~60 seconds, so most of the time is spent in grad() itself.

Profiler output after 2 minutes of execution

My current interpretation is as follows:

In the original design, Yota wasn't supposed to trigger compilation during backpropagation. In fact, the design was very similar to JAX, with the only exception that we used IR-level tracing instead of operator overloading due to issues with multiple dispatch. Just to recap, the first versions of Yota worked like this (a minimal sketch follows the list):

  1. Trace a function to turn it into a list of primitives for which we know the differentiation rules (forward pass).
  2. Record these differentiation rules to the same tape (backward pass).
  3. Compile the tape.
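
A minimal, self-contained sketch of that flow, for illustration only: the tape layout, rule table, and function names below are made up for the example and are not Yota's internals or API.

```julia
# Toy tape-based reverse mode for y = sin(x) * x.
# Hand-written derivative rules for the two primitives on the tape.
deriv_rules = Dict{Function,Function}(
    sin => (args, dy) -> (dy * cos(args[1]),),
    *   => (args, dy) -> (dy * args[2], dy * args[1]),
)

function grad_via_tape(x)
    # 1. "Tracing": the function is recorded as a list of primitive calls,
    #    each entry being (output slot, primitive, input slots).
    vals = Any[x]                               # slot 1 holds the input
    tape = [(2, sin, (1,)), (3, *, (2, 1))]
    for (out, f, ins) in tape                   # forward pass
        push!(vals, f((vals[i] for i in ins)...))
        @assert length(vals) == out
    end
    # 2. Backward pass: seed the output with 1.0 and walk the tape in reverse,
    #    recording/applying the derivative rule of each primitive.
    grads = Dict{Int,Any}(length(vals) => 1.0)
    for (out, f, ins) in reverse(tape)
        dins = deriv_rules[f]([vals[i] for i in ins], grads[out])
        for (i, d) in zip(ins, dins)
            grads[i] = get(grads, i, 0.0) + d
        end
    end
    # 3. In Yota the resulting tape would then be compiled exactly once;
    #    here we simply return the value and the gradient w.r.t. x.
    return vals[end], grads[1]
end

grad_via_tape(1.5)   # -> (sin(1.5) * 1.5, sin(1.5) + 1.5 * cos(1.5))
```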

So: exactly one tracing and one compilation. However, the prevalence of ChainRules changed the game. The current design looks like this:

  1. Trace a function.
  2. Replace each primitive call `y = f(xs...)` with `y, pb = rrule(f, xs...)`. Save the pullbacks for the next step (see the sketch after this list).
  3. Record pullback invocations.
  4. Compile the tape.
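
For reference, step 2 relies on the standard ChainRules interface: rrule returns the primal value together with a pullback closure that is later invoked during the backward pass. A tiny standalone illustration of steps 2 and 3 on a single primitive (outside of Yota):

```julia
using ChainRules: rrule, unthunk

x = 1.5
# Step 2: the primitive call `y = sin(x)` is replaced with a call to its rrule;
# `pb` is the pullback closure that gets saved on the tape.
y, pb = rrule(sin, x)          # y == sin(x)
# Step 3: recording/invoking the pullback. Seeding it with 1.0 yields the
# cotangents w.r.t. the function object itself and its argument.
df, dx = pb(1.0)
dx = unthunk(dx)               # dx == cos(x)
```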

Now Yota has no control over what happens inside an rrule and, in the case of higher-order functions, cannot avoid additional tracing and compilation. Since Flux uses higher-order functions extensively, we get what we get.
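
As a hedged illustration of where the extra tracing comes from, here is a sketch of an rrule for a made-up higher-order function. apply_twice and its rule are hypothetical, but the pattern of calling back into the AD backend via rrule_via_ad is the standard ChainRules mechanism, and for Yota each such call means another round of tracing and compilation at gradient time.

```julia
using ChainRulesCore

# Hypothetical higher-order function used only for illustration.
apply_twice(f, x) = f(f(x))

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode},
                              ::typeof(apply_twice), f, x)
    # The rule has no closed-form derivative for the user-supplied `f`,
    # so it asks the AD backend to differentiate `f` itself. For Yota,
    # each rrule_via_ad call is an extra trace + compile.
    y1, pb1 = rrule_via_ad(config, f, x)
    y2, pb2 = rrule_via_ad(config, f, y1)
    function apply_twice_pullback(ȳ)
        _, dy1 = pb2(ȳ)
        _, dx  = pb1(unthunk(dy1))
        # assume `f` carries no trainable state, so its tangent is NoTangent()
        return NoTangent(), NoTangent(), unthunk(dx)
    end
    return y2, apply_twice_pullback
end
```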

As far as I can see, the only way forward is to speed up tracing. However, it requires a really, really good understanding of the Julia compiler, which I don't have. To my knowledge, the only autodiff package that managed to do it is Diffractor, and not too many people understand how it works.

So I'm pretty much puzzled.

dfdx commented 1 year ago

Thanks to all the commenters in this thread, tracing now works ~2x faster. I also fixed a performance bug in todo_list(), and now the whole grad(loss, model, image) compiles and runs in 61 seconds, which is reasonably good.