FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Adjoint of exp inaccurate for complex matrices #348

Closed: sethaxen closed this issue 5 years ago

sethaxen commented 5 years ago

It's clear from this line that these adjoints aren't supported: https://github.com/FluxML/Zygote.jl/blob/7dff4155f96b6675adb02916c7ad262e855abe1d/src/lib/array.jl#L416 applies real, which throws away the imaginary part of the adjoint. But there is a second problem that persists whether or not real is applied (using ngradient from the tests):

julia> using Zygote, LinearAlgebra, Random
julia> Random.seed!(42);
julia> a, b = randn(3,3), randn(3,3);
julia> function f(a, b)
           c = complex.(a, b)
           ec = exp(c)
           return sum(real.(ec)) + sum(imag.(ec))
       end
f (generic function with 1 method)
julia> da, db = Zygote.gradient(f, a, b);
julia> Δa, Δb = ngradient(f, a, b);
julia> da
3×3 Array{Float64,2}:
 -0.583189  -2.53525    0.0343655 
 -1.67204   -6.1851    -0.00224579
 -0.287775  -0.777853  -0.175802  
julia> db
3×3 Array{Float64,2}:
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0
julia> Δa
3×3 Array{Float64,2}:
 -0.517168  -0.0674414  -0.561498 
  0.536092   1.99185    -0.0965421
 -0.022163   1.14928    -0.312812 
julia> Δb
3×3 Array{Float64,2}:
 -0.583189  -2.53525    0.0343654 
 -1.67204   -6.1851    -0.00224572
 -0.287775  -0.777853  -0.175802  

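Unsurprisingly, db is 0, but da ≈ Δb: the gradient Zygote returns for the real part matches the numerical gradient with respect to the imaginary part. If you remove the real, you also find that db ≈ Δa, so the two components are swapped. But if f zeros out b, so that c has no imaginary part, the correct da is returned.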
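Since ngradient comes from Zygote's tests rather than the package itself, here is a minimal central-difference stand-in for reproducing reference gradients like Δa and Δb. The name fd_gradient and the default step h are hypothetical choices for this sketch, not Zygote functions; it perturbs one real entry of a or b at a time and assumes f returns a real scalar.

```julia
using LinearAlgebra  # provides exp for matrices

# Hypothetical stand-in for the test-suite ngradient: central differences,
# perturbing one entry of `a` or `b` at a time. Assumes f returns a real scalar.
function fd_gradient(f, a::AbstractArray{<:Real}, b::AbstractArray{<:Real}; h=1e-6)
    da, db = zero(a), zero(b)
    for i in eachindex(a)
        ap, am = copy(a), copy(a)
        ap[i] += h
        am[i] -= h
        da[i] = (f(ap, b) - f(am, b)) / (2h)
    end
    for i in eachindex(b)
        bp, bm = copy(b), copy(b)
        bp[i] += h
        bm[i] -= h
        db[i] = (f(a, bp) - f(a, bm)) / (2h)
    end
    return da, db
end

# The function from the issue: a real-valued loss of exp of a complex matrix.
function f(a, b)
    c = complex.(a, b)
    ec = exp(c)
    return sum(real.(ec)) + sum(imag.(ec))
end

a, b = randn(3, 3), randn(3, 3)
Δa, Δb = fd_gradient(f, a, b)
```

The inputs are not seeded here, so the numbers will differ from the session above; the point is only to have a numerical reference to compare against Zygote.gradient(f, a, b).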

I've been trying to work out the changes needed to fix this, but the code cites this paper, which as far as I can tell only gives the forward-mode derivative of exp. I'm not certain how this reverse-mode implementation was derived from it. Can anyone shed any light on this?
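One standard route from the forward-mode derivative to a reverse-mode rule (not necessarily the one the cited paper or Zygote's code takes) is to write the forward derivative of Y = exp(X) as the Fréchet derivative L(X, E) and take its adjoint under the real inner product ⟨A, B⟩ = Re tr(AᴴB). That inner product matches storing complex gradients as ∂f/∂re + i·∂f/∂im. Below, Ȳ and X̄ are the output and input adjoints, and * denotes entrywise complex conjugation.

```latex
L(X, E) = \int_0^1 e^{sX} E \, e^{(1-s)X} \, ds, \qquad dY = L(X, dX)

\langle \bar{Y}, L(X, dX) \rangle
  = \operatorname{Re} \operatorname{tr}\Big( \int_0^1 e^{(1-s)X} \bar{Y}^{H} e^{sX} \, ds \; dX \Big)
  = \langle L(X^{H}, \bar{Y}), dX \rangle
  \;\Longrightarrow\; \bar{X} = L(X^{H}, \bar{Y})

\big( L(X^{T}, \bar{Y}^{*}) \big)^{*} = L\big( (X^{*})^{T}, \bar{Y} \big) = L(X^{H}, \bar{Y})
```

For real X the rule reduces to L(Xᵀ, Ȳ). The last line uses L(X, E)* = L(X*, E*), which holds because the exponential series has real coefficients, and shows that conjugating the input and output of a transpose-based rule gives the conjugate-transpose rule needed for complex X.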

sethaxen commented 5 years ago

I still don't understand how it was derived, but the problem looks like a missing complex conjugation. Simply conjugating the input and output adjoints of Theano's implementation produces the correct adjoint for Zygote. I will submit a PR.
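A minimal Julia sketch of the corrected rule, assuming the convention that a complex adjoint stores ∂f/∂re + im·∂f/∂im. It computes the Fréchet derivative of exp with the block-triangular identity exp([X E; 0 X]) = [exp(X) L(X, E); 0 exp(X)], which is a swapped-in technique rather than whatever Theano or Zygote actually uses. The helpers expm_frechet, exp_pullback and exp_pullback_transpose are hypothetical names. The sketch checks the adjoint against a directional finite difference and shows that conjugating the input and output of a transpose-based adjoint gives the same result.

```julia
using LinearAlgebra, Random

# Hypothetical helper: Fréchet derivative L(X, E) of the matrix exponential,
# read off from the upper-right block of exp([X E; 0 X]).
function expm_frechet(X::AbstractMatrix, E::AbstractMatrix)
    n = size(X, 1)
    M = [X E; zero(X) X]
    return exp(M)[1:n, n+1:2n]
end

# Hypothetical pullback of Y = exp(X): uses the conjugate transpose X'.
exp_pullback(X, Ybar) = expm_frechet(Matrix(X'), Ybar)

# Transpose-only variant: correct for real X, wrong for complex X.
exp_pullback_transpose(X, Ybar) = expm_frechet(Matrix(transpose(X)), Ybar)

# The loss from the issue, written directly on X = complex.(a, b).
loss(X) = sum(real.(exp(X))) + sum(imag.(exp(X)))

Random.seed!(42)
X = complex.(randn(3, 3), randn(3, 3))
Ybar = fill(1.0 + 1.0im, 3, 3)  # d loss/d re(Y) + im * d loss/d im(Y), entrywise

Xbar = exp_pullback(X, Ybar)
da, db = real(Xbar), imag(Xbar)  # gradients with respect to a and b

# Conjugating input and output of the transpose-based adjoint recovers Xbar.
Xbar_conj = conj(exp_pullback_transpose(X, conj(Ybar)))

# Directional finite-difference check along a random complex direction V:
# d/dt loss(X + t*V) at t = 0 should equal Re tr(Xbar^H V) = real(dot(Xbar, V)).
V = complex.(randn(3, 3), randn(3, 3))
h = 1e-6
fd = (loss(X + h * V) - loss(X - h * V)) / (2h)
ad = real(dot(Xbar, V))
```

For real X the two helpers coincide because X' == transpose(X); applying exp_pullback_transpose directly to a complex X generally fails the finite-difference check.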