FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

custom adjoint fails for functions that take tuple as additional argument #214

Closed: sipposip closed this issue 5 years ago

sipposip commented 5 years ago

I am trying to write custom adjoints for FFTW functions (https://github.com/FluxML/Zygote.jl/issues/204). Some of these functions take an optional argument that can be a single number, a tuple, or an array. Custom adjoints seem to fail when the differentiated function takes an additional argument that is a tuple (or an array), even when that argument is completely ignored inside the function. Minimal example:

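using Zygote

# `dummy` is accepted but ignored by the forward computation
f(xs, dummy) = sum(2 * xs)
Zygote.@adjoint function f(xs, dummy)
    return f(xs, dummy), function (Δ)
        # only one gradient is returned, although f takes two arguments
        return (2 * Δ,)
    end
end
# works (scalar extra argument)
Zygote.gradient(x -> f(x, 1), [1, 2, 3])
# fails (tuple extra argument)
Zygote.gradient(x -> f(x, (1, 2)), [1, 2, 3])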

The second call fails with:

ERROR: BoundsError: attempt to access (nothing, 2)
  at index [3]
Stacktrace:
 [1] getindex at ./tuple.jl:24 [inlined]
 [2] gradindex(::Tuple{Nothing,Int64}, ::Int64) at /home/sebastian/.julia/packages/Zygote/VeaFW/src/compiler/reverse.jl:13
 [3] #17 at ./REPL[9]:2 [inlined]
 [4] (::typeof(∂(getfield(Main, Symbol("##17#18"))())))(::Int64) at /home/sebastian/.julia/packages/Zygote/VeaFW/src/compiler/interface2.jl:0
 [5] (::getfield(Zygote, Symbol("##34#35")){typeof(∂(getfield(Main, Symbol("##17#18"))()))})(::Int64) at /home/sebastian/.julia/packages/Zygote/VeaFW/src/compiler/interface.jl:38
 [6] gradient(::Function, ::Array{Int64,1}, ::Vararg{Array{Int64,1},N} where N) at /home/sebastian/.julia/packages/Zygote/VeaFW/src/compiler/interface.jl:47
 [7] top-level scope at none:0

MikeInnes commented 5 years ago

This is intentional, so that you don't accidentally drop gradients: the pullback has to return one entry per argument of f. Just provide a zero gradient for dummy explicitly by returning (2*Δ, nothing).
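Here is a minimal sketch of the fix suggested above. It assumes a recent Zygote.jl release that still provides Zygote.@adjoint. The pullback returns one entry per argument of f: the gradient for xs, then nothing (a zero gradient) for dummy. This sketch departs from the thread in one way. The xs gradient is written as fill(2Δ, size(xs)), so that it has the same shape as xs, instead of the scalar 2*Δ used in the report. That change is an assumption of this sketch and is not part of the original discussion.

```julia
using Zygote

# Same forward function as in the report: `dummy` is ignored.
f(xs, dummy) = sum(2 .* xs)

# The pullback returns one entry per argument of `f`:
# first the gradient for `xs`, then `nothing` (zero gradient) for `dummy`.
Zygote.@adjoint function f(xs, dummy)
    return f(xs, dummy), function (Δ)
        # d/dxs of sum(2 .* xs) is 2 for every element, scaled by Δ
        return (fill(2Δ, size(xs)), nothing)
    end
end

# Scalar and tuple extra arguments now take the same code path.
g1 = Zygote.gradient(x -> f(x, 1), [1.0, 2.0, 3.0])[1]
g2 = Zygote.gradient(x -> f(x, (1, 2)), [1.0, 2.0, 3.0])[1]
```

With this pullback, both calls should return [2.0, 2.0, 2.0] as the gradient for xs. The type of dummy no longer matters, because its gradient slot is filled explicitly.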

sipposip commented 5 years ago

Thanks, that solved it.