mcabbott opened this issue 2 days ago
Thanks for this -- not premature at all. My guess is that Zygote (I assume that's what Flux is using by default still?) has rules for everything here, but Mooncake is deriving its own rule in some performance-critical case. I'll have to take a look at a profile and see what rule(s) I'm missing that are causing poor performance.
Yes to Zygote.
The first should not need very exotic rules: matrix multiplication and .+, tanh.(_), and sum(abs2, _). The second uses conv from NNlib.
Maybe worth including an even simpler Flux-free example:
julia> @btime Flux.gradient(x -> sum(abs2, x), $img)[1][1:3] # really Zygote
min 20.875 μs, mean 48.610 μs (4 allocations, 392.19 KiB)
3-element Vector{Float32}:
1.7552974
1.0280211
0.57361186
julia> let f = x -> sum(abs2, x)
backend = DifferentiationInterface.AutoMooncake(; config=nothing)
prep = @btime DifferentiationInterface.prepare_gradient($f, $backend, $img)
grad = @btime DifferentiationInterface.gradient($f, $prep, $backend, $img)
grad[1:3]
end
min 2.475 ms, mean 3.286 ms (6687 allocations, 12.59 MiB)
min 812.959 μs, mean 866.172 μs (5 allocations, 392.44 KiB)
3-element Vector{Float32}:
1.7552974
1.0280211
0.57361186
And Enzyme, without and with pre-allocated space for the gradient:
julia> @btime Enzyme.gradient(Reverse, x -> sum(abs2, x), $img)[1][1:3]
min 77.083 μs, mean 103.718 μs (4 allocations, 392.19 KiB)
3-element Vector{Float32}:
1.7552974
1.0280211
0.57361186
julia> @btime Enzyme.autodiff(Reverse, x -> sum(abs2, x), Active, $(Duplicated(img, zero.(img))));
min 75.625 μs, mean 82.092 μs (0 allocations)
Yeah, I agree that it ought not to require anything particularly exotic. Mooncake has rules for gemm! and NNlib.conv (or whatever the exact operation is), so it must be something to do with the activations that is causing trouble.
In terms of the performance you're seeing from Mooncake here, there are two things happening in this example:
If, for example, you were doing sum(sin, x) rather than sum(abs2, x), you'd see that Enzyme and Mooncake have quite similar performance. (My point isn't that this is something you would actually want to do, but rather that the operation doesn't have to be that expensive at each iteration for Mooncake to have good performance without deriving a rule.)
To get a sense of the amount of time taken allocating memory:
julia> using BenchmarkTools, Mooncake
julia> img = rand(Float32, 28, 28, 1, 128);
julia> f(x) = sum(abs2, x)
f (generic function with 1 method)
julia> rule = build_rrule(f, img);
julia> @benchmark Mooncake.value_and_gradient!!($rule, f, $img)
BenchmarkTools.Trial: 6059 samples with 1 evaluation.
Range (min … max): 694.000 μs … 20.780 ms ┊ GC (min … max): 0.00% … 95.46%
Time (median): 750.042 μs ┊ GC (median): 0.00%
Time (mean ± σ): 823.888 μs ± 549.142 μs ┊ GC (mean ± σ): 2.50% ± 3.75%
(histogram omitted: frequency by time, 694 μs to 1.17 ms)
Memory estimate: 392.41 KiB, allocs estimate: 5.
julia> @benchmark Mooncake.__value_and_gradient!!($rule, zero_codual(f), $(zero_codual(img)))
BenchmarkTools.Trial: 6520 samples with 1 evaluation.
Range (min … max): 689.917 μs … 22.167 ms ┊ GC (min … max): 0.00% … 0.00%
Time (median): 709.292 μs ┊ GC (median): 0.00%
Time (mean ± σ): 765.979 μs ± 278.923 μs ┊ GC (mean ± σ): 0.00% ± 0.00%
(histogram omitted: log(frequency) by time, 690 μs to 972 μs)
Memory estimate: 0 bytes, allocs estimate: 0.
If I add the following rule to Mooncake:
julia> function Mooncake.rrule!!(::CoDual{typeof(sum)}, ::CoDual{typeof(abs2)}, x::CoDual{<:Array{P}}) where {P<:IEEEFloat}
function sum_abs2_pb(dy::P)
x.dx .+= (2 * dy) .* x.x
return NoRData(), NoRData(), NoRData()
end
return zero_fcodual(sum(abs2, x.x)), sum_abs2_pb
end
julia> Mooncake.@is_primitive Mooncake.DefaultCtx Tuple{typeof(sum), typeof(abs2), Array{<:IEEEFloat}}
and recompute rule, the timings I get for Mooncake.__value_and_gradient!! are:
julia> @benchmark Mooncake.__value_and_gradient!!($rule, zero_codual(f), $(zero_codual(img)))
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 44.166 μs … 370.417 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 45.750 μs ┊ GC (median): 0.00%
Time (mean ± σ): 49.663 μs ± 9.479 μs ┊ GC (mean ± σ): 0.00% ± 0.00%
(histogram omitted: log(frequency) by time, 44.2 μs to 76.4 μs)
Memory estimate: 0 bytes, allocs estimate: 0.
which shows that Mooncake is leaving quite a lot on the table in this case, and highlights just how fast Enzyme is in this situation -- getting to within a factor of two of hand-written performance without writing a rule is quite impressive.
That being said, the primal only takes a few microseconds:
julia> @benchmark f($img)
BenchmarkTools.Trial: 10000 samples with 5 evaluations.
Range (min … max): 6.892 μs … 31.825 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 6.917 μs ┊ GC (median): 0.00%
Time (mean ± σ): 7.236 μs ± 855.013 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
(histogram omitted: log(frequency) by time, 6.89 μs to 9.49 μs)
Memory estimate: 0 bytes, allocs estimate: 0.
so I suspect that even my hand-written rule isn't optimal. I wonder if sum(abs2, x) is doing some SIMD stuff that my hand-written rrule!! is not doing for some reason.
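One way to test the SIMD guess is to write the pullback's accumulation both ways in plain Julia and time them side by side. A sketch (the function names are hypothetical and nothing here is Mooncake API): the first function mirrors the broadcast in the rule above, the second is an explicit @inbounds @simd loop. Both compute dx += 2 * dy * x, the gradient of sum(abs2, x) scaled by dy:

```julia
# Accumulation written as in the rrule!! above (fused broadcast).
function pullback_broadcast!(dx::Vector{T}, x::Vector{T}, dy::T) where {T<:AbstractFloat}
    dx .+= (2 * dy) .* x
    return dx
end

# Same accumulation as an explicit loop; @inbounds @simd allows the compiler
# to vectorise it (whether it actually does depends on the type and the CPU).
function pullback_simd!(dx::Vector{T}, x::Vector{T}, dy::T) where {T<:AbstractFloat}
    c = 2 * dy
    @inbounds @simd for i in eachindex(dx, x)
        dx[i] += c * x[i]
    end
    return dx
end

x = rand(Float32, 28 * 28 * 128)
dx1 = zeros(Float32, length(x))
dx2 = zeros(Float32, length(x))
pullback_broadcast!(dx1, x, 1f0)
pullback_simd!(dx2, x, 1f0)
println(dx1 == dx2)
```

If timing both with @btime shows the loop is no faster than the broadcast, SIMD in the accumulation itself is probably not where the remaining gap to the primal comes from.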
Re pre-allocation: sure, I showed Zygote and Enzyme without it too. To a first approximation, the minimum time doesn't reflect this here.
Sure, in a sense Zygote is cheating because it has a rule for exactly x -> sum(abs2, x) and that's the whole function. Almost any AD system could be similarly engineered to beat the benchmark. Maybe the bigger question is how wide you eventually picture Mooncake's rules spreading -- it would certainly be ideal to need fewer than Zygote does, replacing hand-written with automatic. If that is the goal, then tests like this are how you measure progress.
What's the vision here? Do you think that, long-term, operations like NNlib.relu.(x) ought to have custom rules defined?
For x -> sum(abs2, x), here's the rule definition you describe, with (I think) everything qualified, etc.:
julia> using BenchmarkTools, Mooncake, Zygote
julia> img = rand(Float32, 28, 28, 1, 128);
julia> f(x) = sum(abs2, x)
f (generic function with 1 method)
julia> @btime f($img);
5.667 μs (0 allocations: 0 bytes)
julia> @btime Zygote.gradient(f, $img);
21.292 μs (3 allocations: 392.09 KiB)
julia> rule = build_rrule(f, img);
julia> @btime Mooncake.value_and_gradient!!($rule, f, $img);
630.542 μs (5 allocations: 392.41 KiB)
julia> @btime Mooncake.__value_and_gradient!!($rule, zero_codual(f), $(zero_codual(img)));
621.708 μs (0 allocations: 0 bytes)
julia> function Mooncake.rrule!!(::Mooncake.CoDual{typeof(sum)}, ::Mooncake.CoDual{typeof(abs2)}, x::Mooncake.CoDual{<:Array{P}}) where {P<:Mooncake.IEEEFloat}
function sum_abs2_pb(dy::P)
x.dx .+= (2 * dy) .* x.x
return NoRData(), NoRData(), NoRData()
end
return Mooncake.zero_fcodual(sum(abs2, x.x)), sum_abs2_pb
end
julia> Mooncake.@is_primitive Mooncake.DefaultCtx Tuple{typeof(sum), typeof(abs2), Array{<:Mooncake.IEEEFloat}}
julia> rule2 = build_rrule(f, img);
julia> @btime Mooncake.__value_and_gradient!!($rule2, Mooncake.zero_codual(f), $(Mooncake.zero_codual(img)));
37.708 μs (0 allocations: 0 bytes) # on Julia 1.10, I got 22.375 μs
(jl_MaAIiw) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_MaAIiw/Project.toml`
[da2b9cff] Mooncake v0.4.36
julia> VERSION
v"1.11.0"
If, for example, you were doing sum(sin, x) rather than sum(abs2, x), you'd see that Enzyme and Mooncake have quite similar performance
OK, I tried it, with a slightly harder function than sin, to force AD to get its hands dirty.
Here I think Zygote is using ForwardDiff for the broadcast, because that's what it does. (I don't know whether Enzyme and Mooncake switch to forward mode for broadcasting.)
using Zygote, Enzyme, Mooncake, BenchmarkTools
img = rand(Float32, 28, 28, 1, 128);
red(x) = sum(x -> atan(x, 2f0), x); # sum(f, x), which is certain not to have an existing rule
red2(x) = sum(atan.(x, 2f0)); # implementation using broadcast instead
let
g1 = @btime Zygote.gradient(red, $img) # red(x)=sum(sin,x) allocates less memory here, 1.15 MiB
g2 = @btime Zygote.gradient(red2, $img) # also here, 1.53 MiB
g2[1][1:3]
end
# min 510.625 μs, mean 647.758 μs (39 allocations, 1.92 MiB)
# min 645.250 μs, mean 855.522 μs (52 allocations, 2.30 MiB)
let
g1 = @btime Enzyme.gradient(Reverse, red, $img) # this is the only large change with red(x)=sum(sin,x) instead, -> 290.916 μs
g2 = @btime Enzyme.gradient(Reverse, red2, $img)
g2[1][1:3]
end
# min 77.625 μs, mean 104.632 μs (3 allocations, 392.11 KiB)
# min 939.375 μs, mean 1.051 ms (8 allocations, 1.15 MiB)
let
rule = build_rrule(red, img); # this step min 551.914 ns, mean 603.759 ns (10 allocations, 488 bytes)
g1 = @btime Mooncake.value_and_gradient!!($rule, $red, $img);
rule2 = build_rrule(red2, img);
g2 = @btime Mooncake.value_and_gradient!!($rule2, $red2, $img);
g2[2][2][1:3]
end
# min 926.291 μs, mean 1.023 ms (5 allocations, 392.44 KiB)
# min 2.521 ms, mean 2.672 ms (13 allocations, 1.15 MiB)
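The guess that Zygote handles the broadcast in red2 with forward mode can be made concrete. Each element of atan.(x, 2f0) depends on exactly one input, so a single forward pass with dual numbers gives the whole gradient of the sum. A toy sketch with a hypothetical Dual type (an illustration of the idea, not ForwardDiff's implementation or Zygote's actual code path):

```julia
# Minimal dual number: a value plus its derivative with respect to one input.
struct Dual{T<:Real}
    val::T
    der::T
end

# atan(y, x) with a dual first argument and a plain second argument,
# using d/dy atan(y, x) = x / (x^2 + y^2).
Base.atan(d::Dual, x::Real) = Dual(atan(d.val, x), d.der * x / (x^2 + d.val^2))

# Forward-mode gradient of red2(x) = sum(atan.(x, 2)): seed every element with
# derivative 1, broadcast once, then read off the derivative parts.
function grad_red2(x::AbstractArray{T}) where {T<:AbstractFloat}
    duals = atan.(Dual.(x, one(T)), T(2))
    return map(d -> d.der, duals)
end

x = rand(Float32, 28, 28, 1, 128)
g = grad_red2(x)
println(g[1:3])
```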
Sure, in a sense Zygote is cheating because it has a rule for exactly x -> sum(abs2, x) and that's the whole function. Almost any AD system could be similarly engineered to beat the benchmark. Maybe the bigger question is how wide you eventually picture Mooncake's rules spreading -- it would certainly be ideal to need fewer than Zygote does, replacing hand-written with automatic. If that is the goal, then tests like this are how you measure progress.
Agree with the points you make here. Regarding rules: yes, I definitely anticipate requiring many fewer rules in the short to medium term. Also agree re tracking performance. If you take a look at a Mooncake PR (e.g. https://github.com/compintell/Mooncake.jl/pull/362) you'll see a table of benchmark numbers that gets automatically printed into the PR. It tracks the ratio of AD time to primal time for Mooncake and some other AD packages across a range of functions. In particular, you'll see results for various variants of sum and kron, which have really quite poor numbers. These are (I believe) near-worst-case numbers, so they're roughly an upper bound on how poorly we should expect Mooncake to perform on loops written in Julia.
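The metric in those tables is AD time divided by primal time. A rough plain-Julia sketch of computing such a ratio, where a hand-written gradient of sum(abs2, x) stands in for an AD call (an assumption for illustration; best_time is a hypothetical helper, not the harness Mooncake's CI actually uses):

```julia
# Best-of-n wall-clock time of a zero-argument function, in seconds.
function best_time(f; n=100)
    f()  # warm-up call so compilation is not timed
    return minimum(@elapsed(f()) for _ in 1:n)
end

ratio = let x = rand(Float32, 28, 28, 1, 128), dx = similar(x)
    t_primal = best_time(() -> sum(abs2, x))
    t_grad = best_time(() -> (dx .= 2 .* x))  # stand-in for the gradient computation
    t_grad / t_primal
end
println("gradient time / primal time ≈ ", round(ratio; digits=2))
```

A ratio close to 1 means the gradient costs about as much as the primal; the poor sum and kron entries mentioned above are the cases where this number is large.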
What's the vision here? Do you think that, long-term, operations like NNlib.relu.(x) ought to have custom rules defined?
Long-term I would hope not, but in the short-term they're going to continue to need them (sadly).
Regarding the numbers you're seeing: I see similar numbers for Mooncake locally. Enzyme is doing something truly impressive to get min 77.625 μs, because the primal takes roughly 400 microseconds to run (on my machine), so it must have managed to completely eliminate the primal computation in this case. I assume this is its activity analysis doing its job.
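To see why skipping the primal is possible here: the gradient of red is elementwise d/dx atan(x, 2) = 2 / (x^2 + 4), which never calls atan. So a tool that is only asked for the gradient, not the value red(x), is free to drop the expensive part entirely. A plain-Julia sketch (red_grad is a hypothetical name) checked against a central finite difference:

```julia
red(x) = sum(x -> atan(x, 2f0), x)

# Closed-form gradient: note that nothing here evaluates atan,
# i.e. the primal value red(x) is not needed to compute it.
red_grad(x) = 2 ./ (x .^ 2 .+ 4)

# Central finite-difference check on one element (Float64 for accuracy).
x = rand(10)
h = 1e-6
e = zeros(10); e[1] = h
fd = (red(x + e) - red(x - e)) / (2h)
println(isapprox(red_grad(x)[1], fd; rtol=1e-6))
```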
For context, Mooncake doesn't have any rules for higher-order functions, and doesn't have any special tricks to make use of forwards-mode AD in situations where it might be advantageous to do so, so what you're seeing here is just Mooncake doing reverse-mode AD on Julia IR. That is, it's ADing this with just a rule for atan and low-level rules for array operations (arrayref / arrayset on 1.10, and the Memory and MemoryRef equivalents on 1.11). With that in mind I'm pretty pleased with the performance, but I acknowledge it's got a way to go before it's at a level that people would consider "good".
Moving forwards, is there any chance of the Flux maintainers publishing a set of functions / models that they say "this is the core of the library, and is the bit that really needs to be fast"? In the short term I am quite interested in ensuring people can get good performance using Mooncake for fairly standard DL models, even if the performance drops off a bit for less standard stuff.
Enzyme is doing something truly impressive to get min 77.625 μs
Indeed, quicker than the primal for me too. Very clever, but probably not the number we want. With ReverseWithPrimal it is slower than Zygote.
is there any chance of the Flux maintainers publishing a set of functions / models that they say "this is the core of the library, and is the bit that really needs to be fast"?
We're not very systematic, but these two are the simplest cases in the model zoo. I have not tried any GPU models.
https://github.com/FluxML/Flux.jl/pull/2471 wants to make it super-easy to try Enzyme instead of Zygote. It would be nice to do something similar for Mooncake. Since both (I believe) are happiest when the gradient is allocated up front, making some comparably easy way to do that might be nice.
Correct -- preallocation of gradient storage will tend to give the best performance.
So the ideal for Flux would be for something like Enzyme's Duplicated to be available:
struct MoonPair{X,DX}; x::X; dx::DX; end
MoonPair(x) = MoonPair(x, Mooncake.zero_tangent(x))
Thinking about package extensions, etc., it would actually be simplest for Flux to own this. But perhaps that's weird.
Maybe this is the wrong thread for such interface discussion, though. Maybe https://github.com/FluxML/Flux.jl/pull/2471 is the right one?
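The MoonPair idea amounts to carrying the parameters together with a gradient buffer that is allocated once and reused every step. A plain-Julia sketch of that pattern, with no AD involved: GradPair and gradient_sumabs2! are hypothetical names, and a hand-written gradient of sum(abs2, x) stands in for the AD call. Unlike Mooncake's pullbacks, which accumulate into dx, this stand-in overwrites the buffer, so no re-zeroing is needed between steps:

```julia
# Hypothetical stand-in for MoonPair / Enzyme's Duplicated:
# parameters plus a gradient buffer allocated once up front.
struct GradPair{X,DX}
    x::X
    dx::DX
end
GradPair(x::AbstractArray) = GradPair(x, zero(x))

# Hand-written in-place gradient of sum(abs2, x), standing in for an AD call.
# It overwrites p.dx instead of allocating a new array on every step.
function gradient_sumabs2!(p::GradPair)
    p.dx .= 2 .* p.x
    return p.dx
end

p = GradPair(rand(Float32, 28, 28, 1, 128))
for step in 1:3
    g = gradient_sumabs2!(p)   # same buffer every time
    p.x .-= 0.1f0 .* g         # plain gradient-descent update, in place
end
```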
I don't know whether it's premature to do so, but since we're thinking about how Flux interacts with AD, I tried out a few very simple cases, exactly the same cases as in https://github.com/EnzymeAD/Enzyme.jl/issues/2069: CPU only, Float32, one model without calling NNlib.conv and one with.
Versions: