Update: thanks to all the commenters in this thread, tracing now works ~2x faster. I also fixed a performance bug in `todo_list()`, and now the whole `grad(loss, model, image)` compiles and runs in 61 seconds, which is reasonably good.
Here I'm going to track performance issues with Yota.

Starter code:
```julia
using Yota
using Yota.Umlaut
using Metalhead
using Profile
using ProfileView

loss(model, image) = sum(model(image))

function main()
    model = Metalhead.ResNet(18)
    image = rand(Float32, 224, 224, 3, 1)
    @time model(image)                         # plain forward pass
    @time trace(model, image; ctx=GradCtx())   # initial tracer pass
    @profile grad(loss, model, image)          # the slow part
    Profile.print(mincount=100)
end
```

Currently, the gradient of ResNet takes forever (or at least more than 30 minutes). Executing the function itself takes ~10 seconds and tracing it ~60 seconds, so most of the time is spent in `grad()`.

Profiler output after 2 minutes of execution:
My current interpretation is as follows:

- During backpropagation, Yota invokes the recorded `rrule`s.
- Some `rrule`s invoke `rrule_via_ad`, which in turn triggers tracing of the argument function. So the bottleneck is indeed tracing, though not the initial tracer pass.
- The profile also points at constructing calls on the tape (`mkcall()`). `@nospecialize` helped a bit in tests, but is definitely not a game changer.

In the original design, Yota wasn't supposed to trigger compilation during backpropagation. In fact, the design was very similar to JAX, with the only exception that we used IR-level tracing instead of operator overloading, due to issues with multiple dispatch. Just to recap, the first versions of Yota worked like this:

1. Trace the function once, recording all primitive calls onto a tape.
2. Add derivative expressions for the recorded calls to the same tape (symbolic backprop).
3. Compile the extended tape into a single gradient function.
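A minimal schematic of that original pipeline, assuming Umlaut-style `trace`/`compile` primitives; `record_derivatives!` is a hypothetical placeholder for the symbolic backprop step, not Yota's real internal API:

```julia
using Yota.Umlaut

# Schematic only: record_derivatives! is illustrative, not an actual Yota function.
function old_style_grad(f, args...)
    val, tape = trace(f, args...)   # trace once, recording primitives onto a tape
    record_derivatives!(tape)       # symbolically extend the tape with derivative ops
    compiled = compile(tape)        # compile the extended tape once
    return compiled(f, args...)     # subsequent calls reuse the compiled function
end
```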
So exactly one tracing and one compilation. However, the prevalence of ChainRules changed the game. The current design looks like this:

1. Trace the forward pass, replacing each call `y = f(xs...)` with `y, pb = rrule(f, xs...)`. Save the pullbacks for the next step.
2. Walk the tape backward, applying the saved pullbacks to propagate gradients.
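The substitution in step 1 can be illustrated directly with the ChainRules API; this is a standalone example with `sin`, not Yota's tape machinery:

```julia
using ChainRules   # brings the rrule for sin into scope

x = 0.5

# What the tracer records instead of plain y = sin(x):
y, pb = rrule(sin, x)   # y == sin(x); pb is the saved pullback

# Backward step: the pullback maps the output cotangent to input cotangents.
df, dx = pb(1.0)        # dx == cos(0.5); df == NoTangent()
```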
Now Yota has no control over what happens in an `rrule` and, in the case of higher-order functions, cannot avoid additional tracing and compilation. Since Flux uses higher-order functions extensively, we get what we get (a sketch of the offending pattern follows below).

As far as I can see, the only way forward is to speed up tracing. However, that requires a really, really good understanding of the Julia compiler, which I don't have. To my knowledge, the only autodiff package that has managed to do it is Diffractor, and not too many people understand how it works.
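To make the higher-order case concrete, here is a hedged sketch of how an `rrule` for a map-like function calls back into the AD system through `rrule_via_ad`; `mymap` and its rule are illustrative, not the actual rules shipped in ChainRules:

```julia
using ChainRulesCore

mymap(f, xs) = map(f, xs)

# Illustrative rule: any reverse-mode AD backend (such as Yota) is asked, via
# rrule_via_ad, to differentiate the argument function f itself. For Yota that
# means tracing and compiling f again, which is the extra work described above.
function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(mymap), f, xs)
    ys_and_pbs = map(x -> rrule_via_ad(cfg, f, x), xs)
    ys = map(first, ys_and_pbs)
    function mymap_pullback(dys)
        # Each saved pullback maps an output cotangent to (df, dx); we keep dx
        # and, for brevity, drop the gradient with respect to f.
        dxs = map((ypb, dy) -> last(ypb[2](dy)), ys_and_pbs, dys)
        return NoTangent(), NoTangent(), dxs
    end
    return ys, mymap_pullback
end
```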
So I'm pretty much puzzled.