JuliaDiff / ChainRules.jl

Forward and reverse mode automatic differentiation primitives for Julia Base + StdLibs

rules for `repeat` are missing #383

Closed CarloLucibello closed 3 years ago

CarloLucibello commented 3 years ago

Some of them can be ported over from Zygote.

cf. https://github.com/FluxML/Zygote.jl/issues/906 https://github.com/FluxML/Zygote.jl/blob/956cbcf3c572c0eb09c146189bb38b1b434634ff/src/lib/array.jl#L130
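
For the simplest case, `repeat(x, m)` on a vector, the pullback just accumulates the cotangent entries of each repeated copy back onto the original element. A minimal sketch of that logic (plain functions here; an actual rule would be a `ChainRulesCore.rrule`, and the helper name `repeat_pullback` is made up for illustration):

```julia
# Sketch: pullback for y = repeat(x, m) on a vector x of length n.
# y is m stacked copies of x, so reshape(dy, n, m) puts the j-th copy's
# cotangents in column j; summing across columns accumulates them onto dx.
function repeat_pullback(dy::AbstractVector, n::Integer, m::Integer)
    return vec(sum(reshape(dy, n, m), dims=2))
end

x = [1.0, 2.0, 3.0]
y = repeat(x, 4)                         # length 12
dx = repeat_pullback(ones(length(y)), length(x), 4)
# each element of x appears 4 times, so dx == [4.0, 4.0, 4.0]
```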

Dsantra92 commented 3 years ago

Can I take over this issue?

willtebbutt commented 3 years ago

Please do :) Happy to review a PR resolving this!

mzgubic commented 3 years ago

Yes, sure, you are very welcome to

willtebbutt commented 3 years ago

It would be great if you could do it in two parts:

  1. port across the current repeat rules from Zygote.
  2. add the missing rules.

Doing it in two steps will make it really easy for us to review :)

cossio commented 3 years ago

Bump.

Please note that the rules for Zygote are incomplete: https://github.com/FluxML/Zygote.jl/issues/950

nickrobinson251 commented 3 years ago

PRs welcome! The docs should help introduce the idea of rrules and how to write them, for anyone who wants to contribute :) (https://juliadiff.org/ChainRulesCore.jl/dev/#frule-and-rrule)
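
For anyone new to the pattern: an `rrule` computes the primal result and returns it together with a pullback closure that propagates cotangents backwards. A schematic sketch of the shape, using `sin` as a toy example (shown without loading ChainRulesCore; the name `rrule_sin` is made up, and a real `rrule(::typeof(sin), x)` would also return a `NoTangent()` for the function argument itself):

```julia
# Schematic rrule shape: primal value plus a pullback closure.
function rrule_sin(x)
    y = sin(x)                       # primal computation
    sin_pullback(dy) = dy * cos(x)   # propagate the cotangent: dx = dy * y'(x)
    return y, sin_pullback
end

y, back = rrule_sin(0.0)
# y == 0.0, and back(1.0) == 1.0 since cos(0) == 1
```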

mzgubic commented 3 years ago

I've recommended this issue to @Wimmerer whom I will (likely) co-mentor for his GraphBLAS x ChainRules GSoC project

torfjelde commented 3 years ago

I recently had a go at implementing `repeat` on GPU and the corresponding adjoint. On the more recent versions of Julia (at least 1.6, maybe 1.5; a bit uncertain) it seems more natural to define an adjoint for the methods in `Base._RepeatInnerOuter` rather than for `Base.repeat` directly. Example: https://github.com/JuliaGPU/CUDA.jl/issues/177#issuecomment-838311843.

The benefit of defining adjoints for `repeat_inner_outer` is that everything is standardized there: `inner` and `outer` are always a `Tuple` or `nothing`, never an integer or whatever. Conversely, it's somewhat annoying to define a rule for `repeat` in some other package, since `Base.repeat` is hit before all this convenience kicks in, meaning you have to manually add in stuff like `_RepeatInnerOuter.check(arr, inner, outer)`, replicating the body of `_RepeatInnerOuter.repeat`.

But this is not backwards compatible (and going by the `_` in front of the submodule name in `Base`, I'm guessing the implementation could be changed in the future?) :confused: So I'm not necessarily advocating for using this! I'm just pointing out its existence :)

More importantly, the current implementation in Zygote.jl is unfortunately quite slow :confused: E.g.

using BenchmarkTools, Zygote

xs = collect(1:100)
y, dy = Zygote.pullback(xs) do x
    repeat(x, inner=(7, ))
end

@benchmark $dy($(similar(y) .= 1))

results in

BenchmarkTools.Trial: 
  memory estimate:  99.31 KiB
  allocs estimate:  2101
  --------------
  minimum time:     138.135 μs (0.00% GC)
  median time:      143.639 μs (0.00% GC)
  mean time:        162.003 μs (8.45% GC)
  maximum time:     16.224 ms (98.93% GC)
  --------------
  samples:          10000
  evals/sample:     1

for the current implementation in Zygote.jl, and from the one in my comment above (single-threaded version, not using any parallelism here):

BenchmarkTools.Trial: 
  memory estimate:  1.75 KiB
  allocs estimate:  25
  --------------
  minimum time:     6.816 μs (0.00% GC)
  median time:      7.336 μs (0.00% GC)
  mean time:        7.584 μs (0.00% GC)
  maximum time:     20.593 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     5

And a more complex example:

xs = repeat(collect(1:100), inner=(2, 3, 4))
y, dy = Zygote.pullback(xs) do x
    repeat(x, inner=(7, 1, 10))
end

@benchmark $dy($(similar(y) .= 1))

Current impl:

  memory estimate:  25.65 MiB
  allocs estimate:  504002
  --------------
  minimum time:     44.396 ms (0.00% GC)
  median time:      46.164 ms (0.00% GC)
  mean time:        49.370 ms (5.52% GC)
  maximum time:     62.855 ms (19.16% GC)
  --------------
  samples:          102
  evals/sample:     1

"Naive" impl using KernelAbstractions:

BenchmarkTools.Trial: 
  memory estimate:  20.12 KiB
  allocs estimate:  26
  --------------
  minimum time:     4.912 ms (0.00% GC)
  median time:      5.061 ms (0.00% GC)
  mean time:        5.122 ms (0.00% GC)
  maximum time:     8.344 ms (0.00% GC)
  --------------
  samples:          976
  evals/sample:     1

Including outer too:

xs = repeat(collect(1:100), inner=(2, 3, 4))
y, dy = Zygote.pullback(xs) do x
    repeat(x, inner=(7, 1, 10), outer=(4, 1, 2))
end

@benchmark $dy($(similar(y) .= 1))

Current impl:

BenchmarkTools.Trial: 
  memory estimate:  210.44 MiB
  allocs estimate:  4382402
  --------------
  minimum time:     380.338 ms (3.36% GC)
  median time:      394.218 ms (6.62% GC)
  mean time:        393.056 ms (5.65% GC)
  maximum time:     400.481 ms (3.24% GC)
  --------------
  samples:          13
  evals/sample:     1

"Naive" impl using KA:

BenchmarkTools.Trial: 
  memory estimate:  20.12 KiB
  allocs estimate:  26
  --------------
  minimum time:     39.975 ms (0.00% GC)
  median time:      40.797 ms (0.00% GC)
  mean time:        41.031 ms (0.00% GC)
  maximum time:     47.989 ms (0.00% GC)
  --------------
  samples:          122
  evals/sample:     1

It might seem a bit contrived, but in a project I'm working on I have examples where I have to choose between applying `repeat` before or after some function `f` (just a bunch of linear algebra) with `length(f(x)) == length(x)`. Intuitively you'd expect `repeat` after `f` to be cheaper, since both `repeat` and its adjoint require at most `m` reads and `m` writes (`m = length(f(x))`), while `f` clearly needs to do at least that much work (unless there's a sparsity pattern, though the implementation could easily be specialized to that case too if wanted). But when using this together with Zygote.jl (and the current adjoint of `repeat`), in my problem it turned out not to be the case :confused:

I'm not suggesting ChainRules.jl should depend on KernelAbstractions.jl here, though! I'm just pointing out that copy-pasting the current implementation maybe isn't the optimal approach. I bet we'd get the same performance using a `@generated` function in place of `@kernel`, following the same pattern (without the kernel-launching part). And I'm happy to help out if needed!

mcabbott commented 3 years ago

Zygote has several different implementations for the repeat gradient, https://github.com/FluxML/Zygote.jl/blob/2468e2f1f787e42e9c3d0aa93cec58a0488f8573/src/lib/array.jl#L149-L173 , and the one you're timing seems to be 300x slower than the other one:

julia> xs = collect(1:100)

julia> y_in7, back_in7 = Zygote.pullback(xs) do x  # same as above
           repeat(x, inner=(7, ))
       end;

julia> @btime back_in7($(ones(size(y_in7))));
  144.166 μs (4202 allocations: 132.14 KiB)

julia> y_out7, back_out7 = Zygote.pullback(xs) do x  # also uses slow implementation
           repeat(x, outer=(7, ))
       end;

julia> @btime back_out7($(ones(size(y_out7))));
  144.167 μs (4202 allocations: 132.14 KiB)

julia> y_vec7, back_vec7 = Zygote.pullback(xs) do x
           repeat(x, 7)
       end;

julia> @btime back_vec7($(ones(size(y_vec7))));  # uses reshape & sum
  545.857 ns (13 allocations: 1.19 KiB)

julia> y_out7 ≈ y_vec7
true

julia> @btime repeat($xs, inner=(7,));  # forward pass for comparison
  453.046 ns (1 allocation: 5.62 KiB)

julia> @btime repeat($xs, 7);
  281.323 ns (1 allocation: 5.62 KiB)

The fast one just uses reshape & sum. That ought to be possible for all cases, I think. And the result would then probably work well on CuArrays, too.
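
To make the reshape-&-sum idea concrete for the `inner` case: `repeat(x, inner=(k,))` places the `k` copies of `x[i]` contiguously, so reshaping the cotangent to `k × n` puts all copies of `x[i]` in column `i`, and summing each column recovers `dx`. A minimal sketch (the helper name `inner_repeat_pullback` is made up for illustration):

```julia
# Sketch: pullback for y = repeat(x, inner=(k,)) on a vector x of length n.
# y[(i-1)*k + j] == x[i] for j in 1:k, so reshape(dy, k, n) puts the k
# cotangents belonging to x[i] in column i; summing columns gives dx.
function inner_repeat_pullback(dy::AbstractVector, n::Integer, k::Integer)
    return vec(sum(reshape(dy, k, n), dims=1))
end

x = collect(1.0:3.0)
y = repeat(x, inner=(2,))                 # [1.0, 1.0, 2.0, 2.0, 3.0, 3.0]
dx = inner_repeat_pullback(ones(length(y)), length(x), 2)
# dx == [2.0, 2.0, 2.0]
```

Generalizing to arbitrary `inner`/`outer` tuples amounts to reshaping each dimension into a (copies, original) pair and summing over the copy axes.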

Edit: here's a version you can try out: https://gist.github.com/mcabbott/80ac43cca3bee8f57809155a5240519f . The above times become (on the same machine):

julia> @btime back_in7($(ones(size(y_in7))));  # with _repeat
  572.973 ns (11 allocations: 1.27 KiB)

julia> @btime back_out7($(ones(size(y_out7))));  # with _repeat
  965.609 ns (11 allocations: 1.27 KiB)

julia> @btime back_vec7($(ones(size(y_vec7))));
  465.909 ns (10 allocations: 1.22 KiB) 

(But also, why are you using repeat in something you want to be fast? It always seems wasteful to make many copies of the same data just to make some array the right shape, especially if you aren't then going to mutate it... can't you use broadcasting to do it lazily?)
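
As a toy illustration of that last point (variables here are made up): broadcasting expands dimensions on the fly, so the materialized copy that `repeat` makes is often unnecessary.

```julia
x = [1.0, 2.0, 3.0]    # column
w = [10.0 20.0]        # 1×2 row

# materializing: allocate a 3×2 copy of x, then multiply
a = repeat(x, 1, 2) .* w

# broadcasting: x is virtually expanded across columns, no copy made
b = x .* w

# a == b, but b skips the intermediate repeated array
```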

cossio commented 3 years ago

Doesn't https://github.com/JuliaDiff/ChainRules.jl/pull/460 close this? Note that https://github.com/FluxML/Zygote.jl/issues/950 works now.