dfdx / Yota.jl

Reverse-mode automatic differentiation in Julia
MIT License

Fix rule for `getindex(::Tuple)` #116

Closed mcabbott closed 2 years ago

mcabbott commented 2 years ago

This wants to fix the following error:

julia> using Flux, Yota

julia> model = Chain(Dense(2=>3, tanh), Dense(3=>2));

julia> Yota.grad((m,x) -> sum(abs2, m(x)), model, randn(Float32, 2, 5))
ERROR: MethodError: no method matching zero(::Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}})

After:

julia> _, (_, dm, dx) = Yota.grad((m,x) -> sum(abs2, m(x)), model, randn(Float32, 2, 5));

julia> dm.layers[1].bias
3-element Vector{Float32}:
  0.37451237
 -0.008268073
 -0.29834867

codecov-commenter commented 2 years ago

Codecov Report

Merging #116 (6774742) into main (7bc3e67) will not change coverage. The diff coverage is 100.00%.

@@           Coverage Diff           @@
##             main     #116   +/-   ##
=======================================
  Coverage   85.41%   85.41%           
=======================================
  Files           9        9           
  Lines         473      473           
=======================================
  Hits          404      404           
  Misses         69       69           

Impacted Files    Coverage Δ
src/helpers.jl    54.68% <100.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 7bc3e67...6774742.

mcabbott commented 2 years ago

I'm surprised ChainRules didn't have a rule for this. Could you check whether https://github.com/JuliaDiff/ChainRules.jl/pull/643 looks OK to you?
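
For readers following along, here is a minimal sketch of what a tuple-indexing rule can look like with ChainRulesCore. It is not the code from this PR or from the linked ChainRules PR. The function `pick` is a hypothetical stand-in for `getindex`, used so that the sketch does not overwrite any real rule. The idea is to return a structural `Tangent` with `NoTangent()` in the untouched slots, so that `zero` is never called on non-numeric elements such as a `Dense` layer.

```julia
using ChainRulesCore

# Hypothetical stand-in for `getindex(::Tuple, ::Integer)`; not part of Yota or ChainRules.
pick(t::Tuple, i::Integer) = t[i]

function ChainRulesCore.rrule(::typeof(pick), t::T, i::Integer) where {T<:Tuple}
    y = t[i]
    function pick_pullback(dy)
        # Gradient w.r.t. the tuple: dy in slot i, NoTangent() everywhere else.
        # No element of `t` needs a `zero` method for this to work.
        dt = Tangent{T}(ntuple(j -> j == i ? dy : NoTangent(), length(t))...)
        # One tangent per argument: the function itself, the tuple, the index.
        return NoTangent(), dt, NoTangent()
    end
    return y, pick_pullback
end

# Usage: the tuple holds a String, which has no `zero` method.
y, back = rrule(pick, (1.0, "a", [1, 2]), 3)
_, dt, _ = back([1.0, 1.0])
```

Here `dt` carries `[1.0, 1.0]` in its third slot and `NoTangent()` in the first two, which is the shape of result needed for a tuple of layers like `model.layers`.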

dfdx commented 2 years ago

JuliaDiff/ChainRules.jl#643 definitely looks like the way to go for `getindex(::Tuple)`. Unless I'm missing something, the best solution is to merge that PR and reduce this one to just the test, right?

Also, I only see `getindex(::Array)` and not `getindex(::AbstractArray)` in ChainRules. Do you know whether that's intentional?

mcabbott commented 2 years ago

I think one hope might have been that indexing (say) a transposed array ultimately calls `getindex(::Array)`, so a rule for that alone would compose well. But taking a slice of `A'` then has to wait until it resolves to scalar indexing of `A`, so the AD is going to have to handle many more operations. Having `getindex(::AbstractArray)` is probably a better fit for present AD, and ChainRules 1.0.
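
The point above can be illustrated with plain Julia, no AD package involved. This is only a sketch: `index_with_pullback` is a hypothetical helper written for this example, not a Yota or ChainRules function, and it assumes the indices contain no repeats.

```julia
using LinearAlgebra

A = [1.0 2.0 3.0; 4.0 5.0 6.0]
At = A'                          # lazy Adjoint wrapper around A, no copy

# Scalar indexing of the wrapper forwards to the parent with swapped indices.
@assert parent(At) === A
@assert At[3, 1] == A[1, 3]

# A slice of the wrapper materializes a fresh Vector, built element by element.
s = At[:, 1]
@assert s isa Vector{Float64}
@assert s == A[1, :]

# Hypothetical helper: one pullback written against AbstractArray, so the
# wrapper is handled in a single step instead of per scalar access.
function index_with_pullback(x::AbstractArray, inds...)
    y = x[inds...]
    function pullback(dy)
        dx = zeros(eltype(x), size(x))   # dense gradient with the shape of x
        v = view(dx, inds...)
        v .+= dy                         # assumes no repeated indices
        return dx
    end
    return y, pullback
end

y, back = index_with_pullback(At, :, 1)
dx = back(ones(3))
```

A rule restricted to `Array` would never see the call on `At` directly, so an AD system would have to trace through the wrapper's own indexing code before reaching it. The `AbstractArray` version handles the slice in one step, at the cost of returning a dense gradient.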

dfdx commented 2 years ago

Then I'm going to merge this PR, but not remove Yota-specific rules for getindex(::AbstractArray) yet.

I'll be happy to switch to ChainRules for tuples once JuliaDiff/ChainRules.jl#643 is merged, though.