compintell / Mooncake.jl

https://compintell.github.io/Mooncake.jl/
MIT License

Benchmarking some very simple Flux models #361

Open mcabbott opened 2 days ago

mcabbott commented 2 days ago

I don't know whether it's premature to do so, but since we're thinking about how Flux interacts with AD, I tried out a few very simple cases. Exactly the same cases as in https://github.com/EnzymeAD/Enzyme.jl/issues/2069 : CPU only, Float32, one model without calling NNlib.conv and one with.

julia> using Flux, BenchmarkTools, Mooncake, DifferentiationInterface

julia> mlp = Chain(Flux.flatten, Dense(28^2 => 32, tanh), Dense(32 => 10));

julia> img = rand32(28, 28, 1, 128);

julia> Flux.gradient((m,x) -> sum(abs2, m(x)), mlp, img)[1].layers[2].bias[1:3]
3-element Vector{Float32}:
  -20.508924
    1.8868564
 -105.714165

julia> let f = (m,x) -> sum(abs2, m(x))
         backend = DifferentiationInterface.AutoMooncake(; config=nothing)
         prep = DifferentiationInterface.prepare_gradient(f, backend, mlp, Constant(img))
         grad = DifferentiationInterface.gradient(f, prep, backend, mlp, Constant(img))
         grad.fields.layers[2].fields.bias[1:3]
       end
3-element Vector{Float32}:
  -20.508917
    1.8868563
 -105.71417

julia> @btime $mlp($img);
  min 10.958 μs, mean 14.119 μs (6 allocations, 43.09 KiB)

julia> @btime Flux.gradient((m,x) -> sum(abs2, m(x)), $mlp, $img);
  min 38.250 μs, mean 67.356 μs (86 allocations, 596.27 KiB)

julia> let f = (m,x) -> sum(abs2, m(x))
         backend = DifferentiationInterface.AutoMooncake(; config=nothing)
         prep = @btime DifferentiationInterface.prepare_gradient($f, $backend, $mlp, Constant($img))
         grad = @btime DifferentiationInterface.gradient($f, $prep, $backend, $mlp, Constant($img))
       end;
  min 572.959 μs, mean 846.168 μs (477 allocations, 3.83 MiB)
  min 176.708 μs, mean 237.309 μs (37 allocations, 558.14 KiB)

# a slightly bigger model

julia> lenet = Chain(  # from the model zoo
           Conv((5, 5), 1=>6, relu),
           MaxPool((2, 2)),
           Conv((5, 5), 6=>16, relu),
           MaxPool((2, 2)),
           Flux.flatten,
           Dense(256 => 120, relu),
           Dense(120 => 84, relu), 
           Dense(84 => 10),
       );

julia> Flux.gradient((m,x) -> sum(abs2, m(x)), lenet, img)[1].layers[1].bias[1:3]
3-element Vector{Float32}:
  12.092611
 106.36434
  34.718273

julia> let f = (m,x) -> sum(abs2, m(x))
         backend = DifferentiationInterface.AutoMooncake(; config=nothing)
         prep = DifferentiationInterface.prepare_gradient(f, backend, lenet, Constant(img))
         grad = DifferentiationInterface.gradient(f, prep, backend, lenet, Constant(img))
         grad.fields.layers[1].fields.bias[1:3]
       end
3-element Vector{Float32}:
  12.0926075
 106.36419
  34.718315

julia> @btime Flux.gradient((m,x) -> sum(abs2, m(x)), $lenet, $img);
  min 4.979 ms, mean 6.300 ms (558 allocations, 14.18 MiB)

julia> let f = (m,x) -> sum(abs2, m(x))
         backend = DifferentiationInterface.AutoMooncake(; config=nothing)
         prep = @btime DifferentiationInterface.prepare_gradient($f, $backend, $lenet, Constant($img))
         grad = @btime DifferentiationInterface.gradient($f, $prep, $backend, $lenet, Constant($img))
       end;
  min 71.869 ms, mean 80.095 ms (1791 allocations, 122.64 MiB)
  min 21.686 ms, mean 22.537 ms (434 allocations, 10.38 MiB)

Versions:

(jl_dBnVSh) pkg> st Mooncake
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_dBnVSh/Project.toml`
  [da2b9cff] Mooncake v0.4.36

julia> versioninfo()
Julia Version 1.10.4
Commit 48d4fd48430 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 11 × Apple M3 Pro
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, apple-m1)
Threads: 4 default, 0 interactive, 2 GC (on 5 virtual cores)
Environment:
  JULIA_NUM_THREADS = 4
willtebbutt commented 2 days ago

Thanks for this -- not premature at all. My guess is that Zygote (I assume that's what Flux is using by default still?) has rules for everything here, but Mooncake is deriving its own rule in some performance-critical case. I'll have to take a look at a profile and see what rule(s) I'm missing that are causing poor performance.

mcabbott commented 2 days ago

Yes to Zygote.

The first should not need very exotic rules: matrix multiplication and .+, tanh.(_), sum(abs2, _). The second uses conv from NNlib.
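For concreteness, the reverse pass through one Dense-with-tanh layer really only involves those pieces; here is a plain-Julia sketch (no AD package, toy sizes) with a finite-difference spot check:

```julia
W, b = rand(Float32, 3, 5), rand(Float32, 3)
x = rand(Float32, 5, 7)
z = W * x .+ b
y = tanh.(z)
# reverse pass for loss = sum(y), so the upstream cotangent is all ones
dy = ones(Float32, size(y))
dz = dy .* (1 .- y .^ 2)      # tanh'(z) = 1 - tanh(z)^2
dW = dz * x'                  # matmul rule
db = vec(sum(dz; dims=2))     # .+ broadcast rule for the bias
dx = W' * dz
# spot-check db[1] with a central finite difference
h = 1f-2
bp = copy(b); bp[1] += h
bm = copy(b); bm[1] -= h
fd = (sum(tanh.(W * x .+ bp)) - sum(tanh.(W * x .+ bm))) / (2h)
@assert abs(db[1] - fd) < 1f-2
```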

Maybe worth including an even simpler Flux-free example:

julia> @btime Flux.gradient(x -> sum(abs2, x), $img)[1][1:3]  # really Zygote
  min 20.875 μs, mean 48.610 μs (4 allocations, 392.19 KiB)
3-element Vector{Float32}:
 1.7552974
 1.0280211
 0.57361186

julia> let f = x -> sum(abs2, x)
         backend = DifferentiationInterface.AutoMooncake(; config=nothing)
         prep = @btime DifferentiationInterface.prepare_gradient($f, $backend, $img)
         grad = @btime DifferentiationInterface.gradient($f, $prep, $backend, $img)
         grad[1:3]
       end
  min 2.475 ms, mean 3.286 ms (6687 allocations, 12.59 MiB)
  min 812.959 μs, mean 866.172 μs (5 allocations, 392.44 KiB)
3-element Vector{Float32}:
 1.7552974
 1.0280211
 0.57361186

and Enzyme, without/with pre-allocating space for the gradient:

julia> @btime Enzyme.gradient(Reverse, x -> sum(abs2, x), $img)[1][1:3]
  min 77.083 μs, mean 103.718 μs (4 allocations, 392.19 KiB)
3-element Vector{Float32}:
 1.7552974
 1.0280211
 0.57361186

julia> @btime Enzyme.autodiff(Reverse, x -> sum(abs2, x), Active, $(Duplicated(img, zero.(img))));
 min 75.625 μs, mean 82.092 μs (0 allocations)
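The zero-allocation number above comes from handing Enzyme a pre-made buffer via Duplicated. The idea itself is independent of any AD package; a minimal plain-Julia illustration of writing the gradient of sum(abs2, x) into a reused buffer (hypothetical helper name):

```julia
function grad_sum_abs2!(dx, x)
    @. dx = 2 * x        # gradient of sum(abs2, x), written in place
    return dx
end

x = rand(Float32, 28, 28, 1, 128)
dx = similar(x)          # allocated once, up front
grad_sum_abs2!(dx, x)    # repeated calls allocate nothing
@assert dx == 2 .* x
```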
willtebbutt commented 2 days ago

Yeah, I agree that it ought not to require anything particularly exotic. Mooncake has rules for gemm! and NNlib.conv (or whatever the exact operation is), so it must be something to do with the activations that is causing trouble.

In terms of the performance you're seeing from Mooncake here, there are two things happening in this example:

  1. DifferentiationInterface.jl's preparation does not currently (I don't think) pre-allocate the memory for Mooncake.
  2. Mooncake.jl still has a bit more overhead than Enzyme does (but it does comfortably beat Zygote / ReverseDiff). The majority of the poor performance here stems from the fact that the operation performed at each iteration of the loop used to compute the reduction is very cheap, so the overhead associated with Mooncake dominates. If, for example, you were doing sum(sin, x) rather than sum(abs2, x), you'd see that Enzyme and Mooncake have quite similar performance (my point isn't that this is something you would actually want to do, but rather that the per-iteration operation doesn't have to be that expensive for Mooncake to have good performance without deriving a rule).
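The structural difference being described can be sketched in plain Julia (no Mooncake; illustrative only -- both versions below are fast, the point is that derived AD must apply a pullback once per loop iteration, whereas a rule replaces the whole reduction with one fused operation):

```julia
# Gradient of sum(abs2, x), shaped the way derived AD sees it:
# one pullback application per iteration of the reduction loop.
function grad_loop!(dx, x)
    for i in eachindex(x)
        dx[i] += 2 * x[i]
    end
    return dx
end

# The same gradient, shaped the way a hand-written rule can compute it.
grad_fused!(dx, x) = (dx .+= 2 .* x; dx)

x = rand(Float32, 16)
d1, d2 = zeros(Float32, 16), zeros(Float32, 16)
@assert grad_loop!(d1, x) == grad_fused!(d2, x)
```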

To get a sense of the amount of time taken allocating memory:

julia> using BenchmarkTools, Mooncake

julia> img = rand(Float32, 28, 28, 1, 128);

julia> f(x) = sum(abs2, x)
f (generic function with 1 method)

julia> rule = build_rrule(f, img);

julia> @benchmark Mooncake.value_and_gradient!!($rule, f, $img)
BenchmarkTools.Trial: 6059 samples with 1 evaluation.
 Range (min … max):  694.000 μs …  20.780 ms  ┊ GC (min … max): 0.00% … 95.46%
 Time  (median):     750.042 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   823.888 μs ± 549.142 μs  ┊ GC (mean ± σ):  2.50% ±  3.75%

    ▆▆█▂               ▂▁                                        
  ▆▆████▇▅▄▃▃▃▂▂▂▂▂▃▅▅███▇▅▄▄▃▃▂▂▂▃▂▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▂▂▁▂▂ ▃
  694 μs           Histogram: frequency by time         1.17 ms <

 Memory estimate: 392.41 KiB, allocs estimate: 5.

julia> @benchmark Mooncake.__value_and_gradient!!($rule, zero_codual(f), $(zero_codual(img)))
BenchmarkTools.Trial: 6520 samples with 1 evaluation.
 Range (min … max):  689.917 μs …  22.167 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     709.292 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   765.979 μs ± 278.923 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █▆▆▅▅▄▂▁                    ▁▄▅▅▅▄▄▃▃▃▂▁                      ▂
  ████████▇▆▅▃▄▄▃▅▄▄▄▂▃▃▅▂▅▅▅▆██████████████▇▇▇▆▅▆▅▆▆▆▆▇▆▇▆▇▆▇▆ █
  690 μs        Histogram: log(frequency) by time        972 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.

If I add the following rule to Mooncake:

julia> function Mooncake.rrule!!(::CoDual{typeof(sum)}, ::CoDual{typeof(abs2)}, x::CoDual{<:Array{P}}) where {P<:IEEEFloat}
           function sum_abs2_pb(dy::P)
               x.dx .+= (2 * dy) .* x.x
               return NoRData(), NoRData(), NoRData()
           end
           return zero_fcodual(sum(abs2, x.x)), sum_abs2_pb
       end

julia> Mooncake.@is_primitive Mooncake.DefaultCtx Tuple{typeof(sum), typeof(abs2), Array{<:IEEEFloat}}

and recompute rule, the timings I get for Mooncake.__value_and_gradient!! are:

julia> @benchmark Mooncake.__value_and_gradient!!($rule, zero_codual(f), $(zero_codual(img)))
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  44.166 μs … 370.417 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     45.750 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   49.663 μs ±   9.479 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▅▅█▆▃         ▁▂▂▃▄▄▄▄▃▂▁                                    ▁
  ██████▇▇▇▇▇▆██████████████▇▇█▇▇▆▆▆▆▇▆▆▆▆▆▅▆▅▆▆▅▅▆▆▆▅▄▅▅▅▅▅▄▅ █
  44.2 μs       Histogram: log(frequency) by time      76.4 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.

which shows that Mooncake is leaving quite a lot on the table in this case, and highlights just how fast Enzyme is in this situation -- getting to within a factor of two of hand-written performance without writing a rule is quite impressive.
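As a sanity check on the math in the hand-written rule: for dy = 1 the pullback accumulates 2x, the gradient of sum(abs2, x). A plain finite-difference comparison confirms this without Mooncake in the loop:

```julia
x = rand(Float32, 4)
g = 2 .* x               # what `x.dx .+= (2 * dy) .* x.x` accumulates for dy = 1
h = 1f-2
e1 = [i == 1 ? h : 0f0 for i in eachindex(x)]
fd = (sum(abs2, x .+ e1) - sum(abs2, x .- e1)) / (2h)
@assert abs(g[1] - fd) < 1f-2
```

(The central difference is exact up to rounding here, since abs2 is quadratic.)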

That being said, the primal only takes a few microseconds:

julia> @benchmark f($img)
BenchmarkTools.Trial: 10000 samples with 5 evaluations.
 Range (min … max):  6.892 μs …  31.825 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     6.917 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   7.236 μs ± 855.013 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █ ▄▂                         ▂▁▁▁▁▂▁▁▁                      ▁
  █▆██▅▅▅▆▆▅▅▅▃▄▅▅▇▇▇▇▇▆▇▆▇▇▇████████████▅▆▅▆▆▅▆▄▄▅▄▃▄▃▅▃▅▅▅▄ █
  6.89 μs      Histogram: log(frequency) by time      9.49 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.

so I suspect that even my hand-written rule isn't optimal. I wonder if sum(abs2, x) is doing some SIMD stuff that my hand-written rrule!! is not doing for some reason.

mcabbott commented 2 days ago
  1. Re pre-allocation, sure, I showed Zygote/Enzyme without this too. To a first approximation, the minimum time isn't affected by it here.

  2. Sure, in a sense Zygote is cheating because it has a rule for exactly x -> sum(abs2, x) and that's the whole function. Almost any AD system could be similarly engineered to beat the benchmark. Maybe the bigger question is how wide you eventually picture Mooncake's rules spreading -- it would certainly be ideal to need fewer than Zygote does, replacing hand-written with automatic. If that is the goal, then tests like this are how you measure progress.

What's the vision here? Do you think that, long-term, operations like NNlib.relu.(x) ought to have custom rules defined?

For x -> sum(abs2, x), here's the rule-definition you describe, with I think everything qualified etc.

julia> using BenchmarkTools, Mooncake, Zygote

julia> img = rand(Float32, 28, 28, 1, 128);

julia> f(x) = sum(abs2, x)
f (generic function with 1 method)

julia> @btime f($img);
  5.667 μs (0 allocations: 0 bytes)

julia> @btime Zygote.gradient(f, $img);
  21.292 μs (3 allocations: 392.09 KiB)

julia> rule = build_rrule(f, img);

julia> @btime Mooncake.value_and_gradient!!($rule, f, $img);
  630.542 μs (5 allocations: 392.41 KiB)

julia> @btime Mooncake.__value_and_gradient!!($rule, zero_codual(f), $(zero_codual(img)));
  621.708 μs (0 allocations: 0 bytes)

julia> function Mooncake.rrule!!(::Mooncake.CoDual{typeof(sum)}, ::Mooncake.CoDual{typeof(abs2)}, x::Mooncake.CoDual{<:Array{P}}) where {P<:Mooncake.IEEEFloat}
           function sum_abs2_pb(dy::P)
               x.dx .+= (2 * dy) .* x.x
           return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
           end
           return Mooncake.zero_fcodual(sum(abs2, x.x)), sum_abs2_pb
       end

julia> Mooncake.@is_primitive Mooncake.DefaultCtx Tuple{typeof(sum), typeof(abs2), Array{<:Mooncake.IEEEFloat}}

julia> rule2 = build_rrule(f, img);

julia> @btime Mooncake.__value_and_gradient!!($rule2, Mooncake.zero_codual(f), $(Mooncake.zero_codual(img)));
  37.708 μs (0 allocations: 0 bytes)  # on Julia 1.10, I got 22.375 μs

(jl_MaAIiw) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_MaAIiw/Project.toml`
  [da2b9cff] Mooncake v0.4.36

julia> VERSION
v"1.11.0"

If, for example, you were doing sum(sin, x) rather than sum(abs2, x), you'd see that Enzyme and Mooncake have quite similar performance

Ok, I tried it. With a slightly harder function than sin, to force AD to get its hands dirty.

Here I think Zygote is using ForwardDiff, because that's what it does. (I don't know whether Enzyme & Mooncake switch to forward mode for broadcasting.)

using Zygote, Enzyme, Mooncake

img = rand(Float32, 28, 28, 1, 128);

red(x) = sum(x -> atan(x, 2f0), x);  # a sum(f, x) which is certain not to have an existing rule
red2(x) = sum(atan.(x, 2f0));  # implementation using broadcast instead

let
    g1 = @btime Zygote.gradient(red, $img)   # red(x)=sum(sin,x) allocates less memory here, 1.15 MiB
    g2 = @btime Zygote.gradient(red2, $img)  # also here, 1.53 MiB
    g2[1][1:3]
end
#  min 510.625 μs, mean 647.758 μs (39 allocations, 1.92 MiB)
#  min 645.250 μs, mean 855.522 μs (52 allocations, 2.30 MiB)

let
    g1 = @btime Enzyme.gradient(Reverse, red, $img)  # this is the only large change with red(x)=sum(sin,x) instead, -> 290.916 μs
    g2 = @btime Enzyme.gradient(Reverse, red2, $img)
    g2[1][1:3]
end
#   min 77.625 μs, mean 104.632 μs (3 allocations, 392.11 KiB)
#   min 939.375 μs, mean 1.051 ms (8 allocations, 1.15 MiB)

let
    rule = build_rrule(red, img);  # this step min 551.914 ns, mean 603.759 ns (10 allocations, 488 bytes) 
    g1 = @btime Mooncake.value_and_gradient!!($rule, $red, $img);
    rule2 = build_rrule(red2, img);
    g2 =  @btime Mooncake.value_and_gradient!!($rule2, $red2, $img);
    g2[2][2][1:3]
end
#   min 926.291 μs, mean 1.023 ms (5 allocations, 392.44 KiB)
#   min 2.521 ms, mean 2.672 ms (13 allocations, 1.15 MiB)
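For what it's worth, all any reverse-mode tool needs for red(x) elementwise is the derivative of the two-argument atan in its first argument, d/dt atan(t, c) = c / (t^2 + c^2). A quick plain-Julia check of that identity against a finite difference (hypothetical helper name):

```julia
gred(x) = 2f0 ./ (x .^ 2 .+ 4f0)   # elementwise gradient of sum(t -> atan(t, 2f0), x)
x = rand(Float32, 8)
h = 1f-2
e1 = [i == 1 ? h : 0f0 for i in eachindex(x)]
fd = (sum(t -> atan(t, 2f0), x .+ e1) - sum(t -> atan(t, 2f0), x .- e1)) / (2h)
@assert abs(gred(x)[1] - fd) < 1f-2
```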
willtebbutt commented 2 days ago

Sure, in a sense Zygote is cheating because it has a rule for exactly x -> sum(abs2, x) and that's the whole function. Almost any AD system could be similarly engineered to beat the benchmark. Maybe the bigger question is how wide you eventually picture Mooncake's rules spreading -- it would certainly be ideal to need fewer than Zygote does, replacing hand-written with automatic. If that is the goal, then tests like this are how you measure progress.

Agree with the points you make here. Regarding rules -- yes, I definitely anticipate requiring many fewer rules in the short to medium term. Also agree re tracking performance -- if you take a look at a Mooncake PR (e.g. https://github.com/compintell/Mooncake.jl/pull/362) you'll see a table of benchmark numbers automatically printed into the PR -- they track the AD time / primal time for Mooncake and some other tools across a range of functions. In particular you'll see results for various variants of sum and kron which have really quite poor numbers. These are (I believe) near-worst-case numbers, so they're roughly an upper bound on how poorly we should expect Mooncake to perform on loops written in Julia.

What's the vision here? Do you think that, long-term, operations like NNlib.relu.(x) ought to have custom rules defined?

Long-term I would hope not, but in the short-term they're going to continue to need them (sadly).

Regarding the numbers you're seeing: I see similar numbers for Mooncake locally. Enzyme is doing something truly impressive to get min 77.625 μs, because the primal takes roughly 400 microseconds to run (on my machine), so it must have managed to completely eliminate the primal computation in this case. I assume this is its activity analysis doing its job.

For context, Mooncake doesn't have any rules for higher-order functions, and doesn't have any special tricks to make use of forwards-mode AD in situations where it might be advantageous to do so, so what you're seeing here is just Mooncake doing reverse-mode AD on Julia IR. That is, it's ADing this with just a rule for atan and low-level rules for array operations (arrayref / arrayset on 1.10, and the Memory and MemoryRef equivalents on 1.11). With that in mind I'm pretty pleased with the performance, but I acknowledge it's got a way to go before it's at a level that people would consider "good".

Moving forwards, is there any chance of the Flux maintainers publishing a set of functions / models that they say "this is the core of the library, and is the bit that really needs to be fast"? In the short term I am quite interested in ensuring people can get good performance using Mooncake for fairly standard DL models, even if the performance drops off a bit for less standard stuff.

mcabbott commented 2 days ago

Enzyme is doing something truly impressive to get min 77.625 μs

Indeed, quicker than the primal for me too. Very clever but probably not the number we want. ReverseWithPrimal makes it slower than Zygote.

is there any chance of the Flux maintainers publishing a set of functions / models that they say "this is the core of the library, and is the bit that really needs to be fast"?

We're not very systematic, but these two are the simplest cases in the model zoo. I have not tried any GPU models.

https://github.com/FluxML/Flux.jl/pull/2471 wants to make it super-easy to try Enzyme instead of Zygote. It would be nice to do something similar for Mooncake. Since both (I believe) are happiest when the gradient is allocated up front, making some comparably easy way to do that might be nice.

willtebbutt commented 2 days ago

https://github.com/FluxML/Flux.jl/pull/2471 wants to make it super-easy to try Enzyme instead of Zygote. It would be nice to do something similar for Mooncake. Since both (I believe) are happiest when the gradient is allocated up front, making some comparably easy way to do that might be nice.

Correct -- pre-allocation of gradient storage will tend to give the best performance.

mcabbott commented 2 days ago

So the ideal for Flux would be for something like Enzyme's Duplicated to be available:

struct MoonPair{X,DX}; x::X; dx::DX; end
MoonPair(x) = MoonPair(x, Mooncake.zero_tangent(x))

Thinking about extensions etc, it would actually be simplest for Flux to own this. But perhaps that's weird.

Maybe this is the wrong thread for such interface discussion, though. Maybe https://github.com/FluxML/Flux.jl/pull/2471 is the right one?