Open avik-pal opened 1 month ago

avik-pal commented 1 month ago
using BenchmarkTools, NNlib, Enzyme, Zygote

gelu_act(x) = sum(abs2, gelu.(x))

x = randn(Float32, 32, 32, 1, 32)

@btime Enzyme.gradient(Reverse, $gelu_act, $x); # 1.298 ms (65 allocations: 386.95 KiB)

@btime Zygote.gradient($gelu_act, $x); # 735.499 μs (26 allocations: 384.75 KiB)

This comparison might be somewhat unfair because gelu has an rrule defined.

wsmoses commented 1 month ago

using BenchmarkTools, Enzyme, Zygote

gelu_act(x) = sum(abs2, sin.(x))

x = randn(Float32, 32, 32, 1, 32); 

Enzyme.gradient(Reverse, gelu_act, x); 

@btime Enzyme.gradient(Reverse, $gelu_act, $x); 

@btime Zygote.gradient($gelu_act, $x)

wsmoses commented 1 month ago

So the code generated here by the broadcast, before any Enzyme AD, is actually quite awful. I wrote some primitive optimization passes to do a bit of cleanup (which may fix the runtime activity), but the indexing pattern is still really bad.

@vchuravy any ideas what's happening here (besides it presumably now being > 3 dims, so no specialization by Julia)?

After simplification:
wsmoses commented 1 month ago

This is post my fix btw^

Post fix timings:

julia> @btime Enzyme.gradient(Reverse, $gelu_act, $x);
  927.718 μs (8 allocations: 384.28 KiB)

julia> @btime Zygote.gradient($gelu_act, $x)
  386.687 μs (38 allocations: 641.19 KiB)

The bigger issue right now, IMO, is that the loop bounds aren't statically inferable because of whatever that awful index math is. As a result, inner loops nested inside other loops end up caching the true iteration count, which means a bunch of unnecessary caching. That's still not fixed.

Pre-fix timings are slower, so the minor fix does ~something for perf, but again the index math is likely the root cause. At least others won't have to deal with runtime activity though.

julia> Enzyme.gradient(Reverse, gelu_act, x);

julia> @btime Enzyme.gradient(Reverse, $gelu_act, $x);
  967.698 μs (8 allocations: 384.28 KiB)

julia> @btime Zygote.gradient($gelu_act, $x)
  377.638 μs (38 allocations: 641.19 KiB)
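
One way to probe the "> 3 dims" hypothesis raised earlier in the thread might be to benchmark the same function on a flattened copy of the input, so the broadcast iterates over a 1-D array instead of a 4-D one. The sketch below reuses only the calls already shown in this thread; `xflat` is a name introduced here for illustration, and whether flattening narrows the gap is an open question, not something this thread establishes.

```julia
using BenchmarkTools, Enzyme, Zygote

gelu_act(x) = sum(abs2, sin.(x))

x = randn(Float32, 32, 32, 1, 32)

# Hypothetical 1-D copy of the same data (assumption: not part of the original report)
xflat = vec(copy(x))

# Warm up so compilation is not included in the timings
Enzyme.gradient(Reverse, gelu_act, xflat);
Zygote.gradient(gelu_act, xflat);

# Same benchmarks as above, but over the flattened input
@btime Enzyme.gradient(Reverse, $gelu_act, $xflat);
@btime Zygote.gradient($gelu_act, $xflat);
```

If the Enzyme timing on `xflat` moves noticeably closer to Zygote's, that would point at the multi-dimensional index math rather than at the differentiation itself.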

wsmoses commented 1 month ago

Module pre optimization: preopt.ll.txt