FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Power Series adjoints are subtly wrong in a way gradcheck can't catch #608

Open: oxinabox opened this issue 4 years ago

oxinabox commented 4 years ago

This came up in #366, and it has proved tricky to debug. AFAICT there were two issues. The first came from ChainRules directly (I forgot to put the adjoint for abs(::Complex) back in; that is fixed now). The second is what I think is an error in the actual definition of the power series adjoints, which is what this issue is about.

It is a subtle error that gradcheck does not catch: gradcheck sums the outputs, so the errors cancel because they are almost symmetric (I think they are symmetric up to floating-point error). Something in the ChainRules PR breaks that symmetry just enough to register.
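A toy illustration of how a summed check can hide an error. This is a stdlib-only sketch, not Zygote's actual `gradcheck`. It assumes a check that differentiates a scalar reduction like `sum(sin.(f(x)))`, which amounts to seeding the pullback with `cos.(f(x))`; when `f(x)` is symmetric, that seed is symmetric too. The hypothetical `buggy_pullback` for `f(X) = transpose(X)` forgets the transpose, an error that is invisible for every symmetric seed and only shows up with a generic random one. `fd_vjp` is likewise a hypothetical helper.

```julia
using LinearAlgebra, Random

# Toy forward function f(X) = transpose(X); its correct pullback is Ȳ -> transpose(Ȳ).
f(X) = permutedims(X)
correct_pullback(Ȳ) = permutedims(Ȳ)
# Hypothetical buggy pullback that forgets the transpose.
buggy_pullback(Ȳ) = copy(Ȳ)

# Hypothetical helper: central finite-difference vector-Jacobian product,
# G[k] = <Ȳ, ∂f/∂X[k]>, perturbing one entry of X at a time.
function fd_vjp(f, X, Ȳ; h = 1e-6)
    G = similar(X)
    for k in eachindex(X)
        Xp = copy(X); Xp[k] += h
        Xm = copy(X); Xm[k] -= h
        G[k] = dot(Ȳ, (f(Xp) - f(Xm)) / (2h))
    end
    return G
end

Random.seed!(0)
X = Matrix(Symmetric(randn(3, 3)))  # symmetric input, so f(X) is symmetric too
summed_seed = cos.(f(X))            # seed implied by sum(sin.(f(X))): symmetric
random_seed = randn(3, 3)           # generic seed: almost surely not symmetric

# The bug is invisible with the seed implied by the summed check ...
println(buggy_pullback(summed_seed) ≈ fd_vjp(f, X, summed_seed))    # true
# ... but a random, non-symmetric seed exposes it,
println(buggy_pullback(random_seed) ≈ fd_vjp(f, X, random_seed))    # false
# while the correct pullback passes both.
println(correct_pullback(random_seed) ≈ fd_vjp(f, X, random_seed))  # true
```

Checking a pullback against finite differences with random, non-symmetric seeds is one way to avoid this blind spot.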

Further testing, using FiniteDifferences to look at the whole gradient of A^p as in https://gist.github.com/oxinabox/fad7bae6dc11a7f0de31eff8666656cd, shows Zygote#master giving identical outputs to the ChainRules branch. This means the symmetry break is not in the output we are testing but somewhere else.

(However, I don't 100% trust FiniteDifferences with complex numbers right now. In particular, as of today, the https://github.com/JuliaDiff/FiniteDifferences.jl/pull/76 branch gets the answer wrong in the way shown in https://github.com/JuliaDiff/FiniteDifferences.jl/pull/76#issuecomment-617277713.)

More broadly, this is a great example of what can go wrong with Zygote's current method of gradient checking. Going forward, ChainRules has ChainRulesTestUtils, which uses more complete finite-difference-based tests, and adjoints implemented with ChainRulesCore can use it. It is probably still worth fixing Zygote's gradient tests.


Quoting the relevant bits from #366:

@oxinabox

Current issue: A^p for a symmetric matrix A and p = -0.5; see this reproducer: https://gist.github.com/oxinabox/fad7bae6dc11a7f0de31eff8666656cd

@sethaxen

I'm not certain what's going on with the matrix. If there's a bug, it'll probably be here:

https://github.com/FluxML/Zygote.jl/blob/713715faf0362f91b3a56e1967e3e29ba7740d52/src/lib/array.jl#L619-L627

The adjoint for the power is handled by

https://github.com/FluxML/Zygote.jl/blob/713715faf0362f91b3a56e1967e3e29ba7740d52/src/lib/array.jl#L646 and may be off by a conj.

I probably won't have a chance to check the math until the end of the week.

By comparison with the second adjoint in

https://github.com/FluxML/Zygote.jl/blob/4f7a2eeeb98eabf43dd70c8698aadb0ec6050c48/src/lib/number.jl#L33-L34

This line does look to be missing a conj:

https://github.com/FluxML/Zygote.jl/blob/713715faf0362f91b3a56e1967e3e29ba7740d52/src/lib/array.jl#L626

If this can't wait until the end of the week, I dropped my notes on the derivation here in case someone else wants to check the math or that the code is compatible with it: https://gist.github.com/sethaxen/000d164e515014fdda70601be1ecfb56.

Disclaimer: I don't know if this is the final version; it's just what I found on my machine.
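For context, the standard first-order result for a function of a real symmetric matrix A = U * Diagonal(λ) * U' (often attributed to Daleckii and Krein) gives the pullback Ā = U * (P .* (U' * Ȳ * U)) * U', where P[i,j] = (f(λ[i]) - f(λ[j])) / (λ[i] - λ[j]) off the diagonal and P[i,i] = f'(λ[i]). The stdlib-only sketch below checks that formula against a central finite difference for f(x) = x^(-0.5), the p = -0.5 case from the reproducer. It is an illustration under these assumptions, not Zygote's implementation and not necessarily the derivation in the notes linked above; all helper names are hypothetical. Repeated eigenvalues are handled only by exact equality; nearly equal eigenvalues are where the accuracy of the divided difference becomes the issue.

```julia
using LinearAlgebra, Random

# Hypothetical helper: f(A) for real symmetric A via its eigendecomposition.
function matfun(f, A::Symmetric)
    F = eigen(A)
    return F.vectors * Diagonal(f.(F.values)) * F.vectors'
end

# First divided differences of f at the eigenvalues λ. The diagonal (and any
# exactly repeated eigenvalue) uses the derivative df instead of the quotient.
function divided_differences(f, df, λ)
    n = length(λ)
    P = zeros(eltype(λ), n, n)
    for j in 1:n, i in 1:n
        P[i, j] = λ[i] == λ[j] ? df(λ[i]) : (f(λ[i]) - f(λ[j])) / (λ[i] - λ[j])
    end
    return P
end

# Hypothetical pullback of A -> f(A) for real symmetric A and cotangent Ȳ:
# Ā = U * (P .* (U' * Ȳ * U)) * U'.
function matfun_pullback(f, df, A::Symmetric, Ȳ)
    F = eigen(A)
    U = F.vectors
    P = divided_differences(f, df, F.values)
    return U * (P .* (U' * Ȳ * U)) * U'
end

Random.seed!(1)
B = randn(4, 4)
M = B' * B + I                # symmetric positive definite (eigenvalues >= 1)
powf(x) = x^(-0.5)            # the p = -0.5 case
dpowf(x) = -0.5 * x^(-1.5)
Ȳ = randn(4, 4)               # generic, non-symmetric cotangent
D = randn(4, 4); D = D + D'   # symmetric perturbation direction

Ā = matfun_pullback(powf, dpowf, Symmetric(M), Ȳ)

# <Ā, D> should match the central finite difference of t -> <Ȳ, f(M + t*D)> at t = 0.
h = 1e-6
fd = dot(Ȳ, matfun(powf, Symmetric(M + h * D)) - matfun(powf, Symmetric(M - h * D))) / (2h)
println(isapprox(dot(Ā, D), fd; rtol = 1e-5, atol = 1e-8))
```

Note that Ā is not symmetric for a non-symmetric Ȳ; only its pairing with symmetric perturbations is checked here.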

@antoine-levitt

Sorry, I can't resist the nitpicking, since I've run into this exact problem recently: using a Taylor series to accurately compute terms like (f(a) - f(b)) / (a - b) doesn't get you to O(eps) accuracy; it gets you sqrt(eps) when using a first-order expansion. You can increase the order to n, but then you get something like eps^(n/(n+1)) in the worst case. I don't know of any method that gives full accuracy here for a general function f.

I don't have anything better: I ended up with https://github.com/JuliaMolSim/DFTK.jl/blob/master/src/Smearing.jl#L44, which has O(sqrt(eps)) accuracy using the first derivative (sorry for derailing this topic).
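A minimal sketch of the first-order switch described above, for scalars only. The helper names and the sqrt(eps) threshold are illustrative assumptions, not code from DFTK or Zygote. Above the threshold, the direct quotient loses roughly eps/|a - b| (relative) to cancellation; below it, replacing the quotient by f'(a) is off by roughly |a - b| * |f''| / (2|f'|), which is |a - b|/2 for exp. Both are about sqrt(eps) near the threshold, which is the O(sqrt(eps)) worst case. The reference value uses `expm1` to avoid the cancellation.

```julia
# Hypothetical helper: (f(a) - f(b)) / (a - b), falling back to the
# first-order expansion f'(a) when a and b are within tol of each other.
function divided_difference(f, df, a, b; tol = sqrt(eps(Float64)))
    δ = b - a
    return abs(δ) > tol ? (f(b) - f(a)) / δ : df(a)
end

# Reference for f = exp: (exp(b) - exp(a)) / (b - a) = exp(a) * expm1(δ) / δ,
# which avoids subtracting two nearly equal numbers.
exact_exp_dd(a, b) = exp(a) * expm1(b - a) / (b - a)

a = 1.0
for δ in (1e-12, 1e-9, 1e-8, 2e-8, 1e-6, 1e-3)
    b = a + δ
    approx = divided_difference(exp, exp, a, b)   # exp is its own derivative
    relerr = abs(approx - exact_exp_dd(a, b)) / exact_exp_dd(a, b)
    println("δ = ", δ, "  relative error = ", relerr)
end
```

Raising the order of the expansion shrinks the Taylor term but does nothing about the cancellation in the direct branch, which is why the worst case stays above O(eps).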

sethaxen commented 4 years ago

I agree that the ChainRulesTestUtils approach is better than gradcheck (I have adapted rrule_test for Zygote in a package before). To help diagnose this, what ended up being the fix in #366 that got the power series tests to pass again?

oxinabox commented 4 years ago

Fixed it so that inputs to ChainRules were conjugated, which corrected the results for sin(::Complex).
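To see where the conjugate comes from, here is a small stdlib-only check. It assumes the convention that the gradient of a real-valued L at z = x + iy is ∂L/∂x + i ∂L/∂y, which, as far as I know, matches Zygote's documented convention for complex inputs; the helper `fd_complex_grad` is hypothetical. For L(z) = real(sin(z)), the Cauchy-Riemann equations give conj(cos(z)) rather than cos(z), so a rule that returns the plain derivative is off by exactly a conj.

```julia
# Hypothetical helper: finite-difference gradient of a real-valued L at complex z,
# returned as dL/dx + i dL/dy (central differences along x and along y).
function fd_complex_grad(L, z; h = 1e-6)
    dLdx = (L(z + h) - L(z - h)) / (2h)
    dLdy = (L(z + im * h) - L(z - im * h)) / (2h)
    return complex(dLdx, dLdy)
end

z = 0.3 + 0.7im
g = fd_complex_grad(w -> real(sin(w)), z)

println(isapprox(g, conj(cos(z)); rtol = 1e-7))  # true: conjugated derivative matches
println(isapprox(g, cos(z); rtol = 1e-7))        # false: plain derivative does not
```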

sethaxen commented 4 years ago

I'm not certain if this is the cause, but it's a little tricky to test these power series functions using FiniteDifferences. Here's a simple example:

```julia
julia> using LinearAlgebra, Random, Zygote, FiniteDifferences

julia> Random.seed!(42);

julia> _fdm = central_fdm(5, 1; adapt=5);

julia> seed = randn(3, 3)
3×3 Array{Float64,2}:
 -0.556027   -0.299484  -0.468606
 -0.444383    1.77786    0.156143
  0.0271553  -1.1449    -2.64199

julia> A = Symmetric(randn(3, 3))
3×3 Symmetric{Float64,Array{Float64,2}}:
  1.00331   0.518149  -0.886205
  0.518149  1.49138    0.684565
 -0.886205  0.684565  -1.59058

julia> _, zpb = Zygote.pullback(exp, A);

julia> zg = zpb(seed)[1]
3×3 Array{Float64,2}:
 -2.11602    0.648903   0.634618
 -0.680315   7.49855    0.822502
  0.54197   -1.24868   -1.01659

julia> ng = conj.(FiniteDifferences.j′vp(_fdm, exp, seed, A))[1]
3×3 Symmetric{Float64,Array{Float64,2}}:
 -2.11602    -0.0314117   1.17659
 -0.0314117   7.49855    -0.426182
  1.17659    -0.426182   -1.01659
```

The problem here is that FiniteDifferences constrains the j′vp to have the same type as the input, so it makes the result symmetric (see https://github.com/JuliaDiff/FiniteDifferences.jl/pull/76#issuecomment-617784173). Zygote has no such constraint. (I think the FD approach essentially gives us Zygote's intended adjoint followed by a projection onto the tangent space of the manifold defined by the constraint, though elements of the tangent space cannot in general be represented as points on the manifold.) If you continue pulling the adjoint back through a Symmetric call, the two agree.

```julia
julia> _, zpb = Zygote.pullback(exp ∘ Symmetric, collect(A));

julia> zg = zpb(seed)[1]
3×3 Array{Float64,2}:
 -2.11602  -0.0314117   1.17659
  0.0       7.49855    -0.426182
  0.0       0.0        -1.01659

julia> ng = conj.(FiniteDifferences.j′vp(_fdm, exp ∘ Symmetric, seed, collect(A)))[1]
3×3 Array{Float64,2}:
 -2.11602      -0.0314117     1.17659
 -1.02408e-15   7.49855      -0.426182
 -1.02408e-15  -1.02408e-15  -1.01659
```

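The relationship between the two results above can be checked by hand. Assuming the default `uplo = :U`, `Symmetric(X)` reads only the upper triangle of X, so pulling a cotangent G back through it should put G[i,j] + G[j,i] on the strictly upper triangle, keep G[i,i] on the diagonal, and leave zeros below. The stdlib-only sketch below (with a hypothetical helper `symmetric_pullback`) applies this to the unconstrained Zygote gradient printed earlier and recovers the `exp ∘ Symmetric` gradient to printed precision; wrapping the result in `Symmetric` reproduces the FiniteDifferences output for the `Symmetric` input.

```julia
using LinearAlgebra

# Hypothetical helper: pullback of X -> Symmetric(X) (uplo = :U) with respect to X.
# Strict upper triangle: G[i,j] + G[j,i]; diagonal: G[i,i]; strict lower triangle: 0.
symmetric_pullback(G::AbstractMatrix) = triu(G) + transpose(tril(G, -1))

# Unconstrained gradient from Zygote.pullback(exp, A), copied from the output above.
zg = [-2.11602   0.648903   0.634618;
      -0.680315  7.49855    0.822502;
       0.54197  -1.24868   -1.01659]

# Gradient from Zygote.pullback(exp ∘ Symmetric, collect(A)), copied from above.
zg_sym = [-2.11602 -0.0314117  1.17659;
           0.0      7.49855   -0.426182;
           0.0      0.0       -1.01659]

# FiniteDifferences result for the Symmetric input A, copied from above.
ng = [-2.11602   -0.0314117  1.17659;
      -0.0314117  7.49855   -0.426182;
       1.17659   -0.426182  -1.01659]

println(isapprox(symmetric_pullback(zg), zg_sym; atol = 1e-4))
println(isapprox(Symmetric(symmetric_pullback(zg)), ng; atol = 1e-4))
```

The off-diagonal entries of the FiniteDifferences result are sums G[i,j] + G[j,i] rather than averages, consistent with perturbing one upper-triangle entry of a `Symmetric` matrix changing two entries of the full matrix.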
oxinabox commented 4 years ago

right, so do you think that means all is well?

sethaxen commented 4 years ago

> right, so do you think that means all is well?

So far so good: https://gist.github.com/sethaxen/fa67e541c4a2a5e773b475349ed87fb9.

Still a couple of edge cases to check.