CarloLucibello closed this issue 3 years ago.
Can I take over this issue?
Please do :) Happy to review a PR resolving this!
Yes, sure, you are very welcome to
It would be great if you could do it in two parts:

- `repeat`
- the rules from Zygote

Doing it in two steps will make it really easy for us to review :)
Bump.
Please note that the rules for Zygote are incomplete: https://github.com/FluxML/Zygote.jl/issues/950
PRs welcome! The docs should help introduce the idea of `rrule`s and how to write them, for anyone who wants to contribute :) (https://juliadiff.org/ChainRulesCore.jl/dev/#frule-and-rrule)
I've recommended this issue to @Wimmerer, whom I will (likely) co-mentor for his GraphBLAS x ChainRules GSoC project.
I recently had a go at implementing `repeat` on GPU and the corresponding adjoint, and on the more recent versions of Julia (at least 1.6, maybe 1.5, I'm a bit uncertain) it seems more natural to define an adjoint for the methods in `Base._RepeatInnerOuter` rather than for `Base.repeat` directly. Example: https://github.com/JuliaGPU/CUDA.jl/issues/177#issuecomment-838311843. The benefit of defining adjoints for `repeat_inner_outer` is that everything is standardized: `inner` and `outer` are always a `Tuple` or `Nothing`, never an integer or anything else. This dispatch structure also makes it somewhat annoying to define a rule for `repeat` in some other package, since `Base.repeat` is hit before all this convenience kicks in, meaning that you have to manually add in stuff like `_RepeatInnerOuter.check(arr, inner, outer)`, replicating the body of `_RepeatInnerOuter.repeat`.

But this is not backwards compatible (and going by the `_` in front of the submodule name in Base, I'm guessing the implementation could change in the future?) :confused: So I'm not necessarily advocating for using this! I'm just pointing out its existence :)
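For illustration, here's a hypothetical Python sketch of the kind of argument normalization I mean (the function name and signature are mine, not from Base; Python just keeps the sketch short):

```python
def normalize_counts(counts, ndim):
    """Normalize a repeat-count spec to a tuple of length ndim.

    Loosely mirrors the idea behind Base._RepeatInnerOuter: callers may
    pass nothing, a single integer, or a short tuple, but downstream
    code only ever sees a full tuple of per-dimension counts.
    """
    if counts is None:
        counts = ()
    elif isinstance(counts, int):
        counts = (counts,)
    # Pad missing trailing dimensions with 1 (i.e. "repeat once").
    return tuple(counts) + (1,) * (ndim - len(counts))
```

With this, a rule only ever has to handle the fully-normalized form, e.g. `normalize_counts(7, 2)` gives `(7, 1)` and `normalize_counts(None, 3)` gives `(1, 1, 1)`.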
More importantly, the current implementation in Zygote.jl is unfortunately quite slow :confused: E.g.
```julia
using Zygote, BenchmarkTools

xs = collect(1:100)
y, dy = Zygote.pullback(xs) do x
    repeat(x, inner=(7,))
end

@benchmark $dy($(similar(y) .= 1))
```
results in
```
BenchmarkTools.Trial:
  memory estimate:  99.31 KiB
  allocs estimate:  2101
  --------------
  minimum time:     138.135 μs (0.00% GC)
  median time:      143.639 μs (0.00% GC)
  mean time:        162.003 μs (8.45% GC)
  maximum time:     16.224 ms (98.93% GC)
  --------------
  samples:          10000
  evals/sample:     1
```
for the current implementation in Zygote.jl, while the one from my comment above (single-threaded version, not using any parallelism here) gives:
```
BenchmarkTools.Trial:
  memory estimate:  1.75 KiB
  allocs estimate:  25
  --------------
  minimum time:     6.816 μs (0.00% GC)
  median time:      7.336 μs (0.00% GC)
  mean time:        7.584 μs (0.00% GC)
  maximum time:     20.593 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     5
```
And a more complex example:
```julia
xs = repeat(collect(1:100), inner=(2, 3, 4))
y, dy = Zygote.pullback(xs) do x
    repeat(x, inner=(7, 1, 10))
end
```
Current impl:
```
BenchmarkTools.Trial:
  memory estimate:  25.65 MiB
  allocs estimate:  504002
  --------------
  minimum time:     44.396 ms (0.00% GC)
  median time:      46.164 ms (0.00% GC)
  mean time:        49.370 ms (5.52% GC)
  maximum time:     62.855 ms (19.16% GC)
  --------------
  samples:          102
  evals/sample:     1
```
"Naive" impl using KernelAbstractions:
```
BenchmarkTools.Trial:
  memory estimate:  20.12 KiB
  allocs estimate:  26
  --------------
  minimum time:     4.912 ms (0.00% GC)
  median time:      5.061 ms (0.00% GC)
  mean time:        5.122 ms (0.00% GC)
  maximum time:     8.344 ms (0.00% GC)
  --------------
  samples:          976
  evals/sample:     1
```
Including `outer` too:

```julia
xs = repeat(collect(1:100), inner=(2, 3, 4))
y, dy = Zygote.pullback(xs) do x
    repeat(x, inner=(7, 1, 10), outer=(4, 1, 2))
end

@benchmark $dy($(similar(y) .= 1))
```
Current impl:
```
BenchmarkTools.Trial:
  memory estimate:  210.44 MiB
  allocs estimate:  4382402
  --------------
  minimum time:     380.338 ms (3.36% GC)
  median time:      394.218 ms (6.62% GC)
  mean time:        393.056 ms (5.65% GC)
  maximum time:     400.481 ms (3.24% GC)
  --------------
  samples:          13
  evals/sample:     1
```
"Naive" impl using KA:
```
BenchmarkTools.Trial:
  memory estimate:  20.12 KiB
  allocs estimate:  26
  --------------
  minimum time:     39.975 ms (0.00% GC)
  median time:      40.797 ms (0.00% GC)
  mean time:        41.031 ms (0.00% GC)
  maximum time:     47.989 ms (0.00% GC)
  --------------
  samples:          122
  evals/sample:     1
```
It might seem a bit contrived, but I have examples in a project I'm working on where I have to choose between applying `repeat` before or after some function `f` (which is just a bunch of linear algebra) with `length(f(x)) == length(x)`. Intuitively you'd expect `repeat` after `f` to be cheaper, since both `repeat` and its adjoint require at most `m` reads and `m` writes (`m = length(f(x))`), while `f` clearly needs to perform at least as much work (unless there's a sparsity pattern, but the implementation could easily be specialized to that case too if wanted). But when using this together with Zygote.jl (and the current implementation of the adjoint of `repeat`) in my problem, it turned out not to be the case :confused:
I'm not suggesting ChainRules.jl should depend on KernelAbstractions.jl here though! I'm just pointing out that copy-pasting the current implementation maybe isn't the optimal approach. I bet we'd get the same performance using a `@generated` function in place of `@kernel` and just using the same pattern (without the kernel-launching part). And I'm happy to help out if needed!
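To make the per-element index arithmetic concrete, here is a single-threaded Python/NumPy stand-in for what such a kernel (or a `@generated` function) would compute; `repeat_pullback` is a hypothetical name of mine, and it follows Julia's `repeat` semantics, where `inner` repeats each element consecutively and `outer` tiles whole blocks:

```python
import numpy as np

def repeat_pullback(dy, x_shape, inner, outer):
    """Accumulate the cotangent of repeat(x; inner, outer) back onto an
    array of shape x_shape: one read and one add per output element,
    the same access pattern a GPU kernel would use per thread."""
    dx = np.zeros(x_shape, dtype=dy.dtype)
    for out_idx in np.ndindex(dy.shape):
        # Along each dimension of size n with inner count i, the output
        # has size n*i*o. The position within one outer tile is
        # j % (n*i), and each input element occupies i consecutive
        # slots there. (outer only determines dy.shape; the modulus
        # already handles the tiling.)
        src = tuple((j % (n * i)) // i
                    for j, n, i in zip(out_idx, x_shape, inner))
        dx[src] += dy[out_idx]
    return dx
```

Pulling back a vector of ones recovers each input element's multiplicity, e.g. `repeat_pullback(np.ones(18), (3,), (2,), (3,))` gives `[6., 6., 6.]`, since `inner=2` and `outer=3` copy each element six times.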
Zygote has several different implementations for the repeat gradient, https://github.com/FluxML/Zygote.jl/blob/2468e2f1f787e42e9c3d0aa93cec58a0488f8573/src/lib/array.jl#L149-L173 , and the one you're timing seems to be 300x slower than the other one:
```julia
julia> xs = collect(1:100);

julia> y_in7, back_in7 = Zygote.pullback(xs) do x  # same as above
           repeat(x, inner=(7,))
       end;

julia> @btime back_in7($(ones(size(y_in7))));
  144.166 μs (4202 allocations: 132.14 KiB)

julia> y_out7, back_out7 = Zygote.pullback(xs) do x  # also uses slow implementation
           repeat(x, outer=(7,))
       end;

julia> @btime back_out7($(ones(size(y_out7))));
  144.167 μs (4202 allocations: 132.14 KiB)

julia> y_vec7, back_vec7 = Zygote.pullback(xs) do x
           repeat(x, 7)
       end;

julia> @btime back_vec7($(ones(size(y_vec7))));  # uses reshape & sum
  545.857 ns (13 allocations: 1.19 KiB)

julia> y_out7 ≈ y_vec7
true

julia> @btime repeat($xs, inner=(7,));  # forward pass for comparison
  453.046 ns (1 allocation: 5.62 KiB)

julia> @btime repeat($xs, 7);
  281.323 ns (1 allocation: 5.62 KiB)
```
The fast one just uses `reshape` & `sum`. That ought to be possible for all cases, I think. And the result would then probably work well on `CuArray`s, too.
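For anyone who wants the idea without reading the Julia source, here is the reshape-and-sum trick sketched in NumPy (the function names are mine; note NumPy reshapes row-major, so the axis order is flipped relative to column-major Julia, where for `inner` you'd reshape to `(r, n)` and sum over `dims=1`):

```python
import numpy as np

def pullback_inner(dy, n, r):
    # Gradient of y = np.repeat(x, r) for a length-n x: each x[k] was
    # copied into r consecutive slots of y, so fold y back to (n, r)
    # and sum away the copy axis. One reduction, no scalar loop.
    return dy.reshape(n, r).sum(axis=1)

def pullback_outer(dy, n, r):
    # Gradient of y = np.tile(x, r): the r copies are whole blocks,
    # so fold to (r, n) and sum over the block axis instead.
    return dy.reshape(r, n).sum(axis=0)
```

Both are single fused reductions over a reshaped view, which is why this approach also maps well onto GPU arrays without writing a custom kernel.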
Edit: here's a version you can try out: https://gist.github.com/mcabbott/80ac43cca3bee8f57809155a5240519f . The above times become (on the same machine):
julia> @btime back_in7($(ones(size(y_in7)))); # with _repeat
572.973 ns (11 allocations: 1.27 KiB)
julia> @btime back_out7($(ones(size(y_out7)))); # with _repeat
965.609 ns (11 allocations: 1.27 KiB)
julia> @btime back_vec7($(ones(size(y_vec7))));
465.909 ns (10 allocations: 1.22 KiB)
(But also, why are you using `repeat` in something you want to be fast? It always seems wasteful to make many copies of the same data just to make some array the right shape, especially if you aren't then going to mutate it... can't you use broadcasting to do it lazily?)
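As a concrete illustration of the lazy alternative (NumPy here for a self-contained example; in Julia the broadcast would simply fuse):

```python
import numpy as np

x = np.arange(3.0)                   # shape (3,)
w = np.arange(12.0).reshape(4, 3)    # shape (4, 3)

# Materialized: physically copy x into 4 rows before multiplying.
dense = np.repeat(x[None, :], 4, axis=0) * w

# Lazy: broadcast_to returns a zero-copy read-only view, so the
# repeated data is never actually allocated.
lazy = np.broadcast_to(x, (4, 3)) * w

assert np.array_equal(dense, lazy)
```

The two results are identical; only the materialized version pays for the intermediate copies.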
Doesn't https://github.com/JuliaDiff/ChainRules.jl/pull/460 close this? Note that https://github.com/FluxML/Zygote.jl/issues/950 works now.
Some of them can be ported over from Zygote.
cf. https://github.com/FluxML/Zygote.jl/issues/906 https://github.com/FluxML/Zygote.jl/blob/956cbcf3c572c0eb09c146189bb38b1b434634ff/src/lib/array.jl#L130