oxinabox opened this issue 4 years ago
I agree that the ChainRulesTestUtils approach is better than `gradcheck` (I have adapted `rrule_test` for Zygote for a package before). To aid in diagnosing: what ended up being the fix in #366 that got the power series tests to pass again?
Fixed it so that inputs to ChainRules were conjugated, which corrected the results for `sin(::Complex)`.
I'm not certain if this is the cause, but it's a little tricky to test these power series functions using FiniteDifferences. Here's a simple example:
```julia
julia> using LinearAlgebra, Random, Zygote, FiniteDifferences

julia> Random.seed!(42);

julia> _fdm = central_fdm(5, 1; adapt=5);

julia> seed = randn(3, 3)
3×3 Array{Float64,2}:
-0.556027 -0.299484 -0.468606
-0.444383 1.77786 0.156143
0.0271553 -1.1449 -2.64199

julia> A = Symmetric(randn(3, 3))
3×3 Symmetric{Float64,Array{Float64,2}}:
1.00331 0.518149 -0.886205
0.518149 1.49138 0.684565
-0.886205 0.684565 -1.59058

julia> _, zpb = Zygote.pullback(exp, A);

julia> zg = zpb(seed)[1]
3×3 Array{Float64,2}:
-2.11602 0.648903 0.634618
-0.680315 7.49855 0.822502
0.54197 -1.24868 -1.01659

julia> ng = conj.(FiniteDifferences.j′vp(_fdm, exp, seed, A))[1]
3×3 Symmetric{Float64,Array{Float64,2}}:
-2.11602 -0.0314117 1.17659
-0.0314117 7.49855 -0.426182
1.17659 -0.426182 -1.01659
```
The problem here is that FiniteDifferences constrains the j′vp to have the same type as the input, so the result is made symmetric (see https://github.com/JuliaDiff/FiniteDifferences.jl/pull/76#issuecomment-617784173). Zygote has no such constraint. I think the FD approach essentially gives us Zygote's intended adjoint followed by a projection onto the tangent space of the manifold defined by the constraint, though in general elements of the tangent space cannot be represented as points on the manifold. If you continue pulling back the adjoint through a `Symmetric` call, then the two agree:
```julia
julia> _, zpb = Zygote.pullback(exp ∘ Symmetric, collect(A));

julia> zg = zpb(seed)[1]
3×3 Array{Float64,2}:
-2.11602 -0.0314117 1.17659
0.0 7.49855 -0.426182
0.0 0.0 -1.01659

julia> ng = conj.(FiniteDifferences.j′vp(_fdm, exp ∘ Symmetric, seed, collect(A)))[1]
3×3 Array{Float64,2}:
-2.11602 -0.0314117 1.17659
-1.02408e-15 7.49855 -0.426182
-1.02408e-15 -1.02408e-15 -1.01659
```
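As a sanity check on the numbers in these two sessions, here is a minimal sketch, using only LinearAlgebra, of what pulling an unconstrained cotangent back through `Symmetric` does. `Symmetric(X)` reads only the upper triangle by default, so each strictly upper entry `X[i, j]` receives `G[i, j] + G[j, i]`, the diagonal receives `G[i, i]`, and the strict lower triangle receives zero. The helper `pullback_through_symmetric` is hypothetical (not part of Zygote or FiniteDifferences), and `zg` is the Zygote cotangent copied from the first session.

```julia
using LinearAlgebra

# Hypothetical helper (not part of Zygote or FiniteDifferences): pull an
# unconstrained cotangent `G` back through `Symmetric(X)` with the default
# `:U` storage. Only the upper triangle of `X` is read, so each strictly
# upper entry X[i, j] feeds both outputs (i, j) and (j, i) and receives
# G[i, j] + G[j, i]; the diagonal receives G[i, i]; the strict lower
# triangle is never read and receives zero.
function pullback_through_symmetric(G::AbstractMatrix)
    return triu(G + transpose(G), 1) + Diagonal(diag(G))
end

# Zygote's cotangent for `exp` at the Symmetric input `A` (first session)
zg = [-2.11602   0.648903  0.634618;
      -0.680315  7.49855   0.822502;
       0.54197  -1.24868  -1.01659]

upper = pullback_through_symmetric(zg)   # compare with the `exp ∘ Symmetric` session
sym = Symmetric(upper)                   # compare with the Symmetric-typed FD result
```

Up to the printed precision, `upper` matches Zygote's `exp ∘ Symmetric` output, and `sym` matches the Symmetric-typed FiniteDifferences result from the first session, which is consistent with the explanation above.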
right, so do you think that means all is well?
So far so good: https://gist.github.com/sethaxen/fa67e541c4a2a5e773b475349ed87fb9.
Still a couple of edge cases to check.
This came up in #366, and it has proved tricky to debug. AFAICT there were two issues. The first was in ChainRules directly: I forgot to put the adjoint for `abs(::Complex)` back in (fixed now). The second is what I think is an error in the actual definition of the power series adjoints, which is what this issue is about.

It is a subtle thing that `gradcheck` does not catch, because `gradcheck` sums things, which results in the errors canceling, because they are almost symmetric (I think they are symmetric up to floating-point error). But something in the ChainRules PR breaks the symmetry just enough to register.

Further testing, using FiniteDifferences to look at the whole gradient of `A^p` as in https://gist.github.com/oxinabox/fad7bae6dc11a7f0de31eff8666656cd, shows Zygote#master giving identical outputs to the ChainRules branch. (This means the symmetry break is not in the actual output we are testing, but somewhere else.)

(However, I don't 100% trust FiniteDifferences with complex numbers right now. In particular, as of today, the https://github.com/JuliaDiff/FiniteDifferences.jl/pull/76 branch gets the answer wrong in the same way shown in https://github.com/JuliaDiff/FiniteDifferences.jl/pull/76#issuecomment-617277713.)
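One way this kind of cancellation can happen is sketched below with made-up numbers; it is not the actual power series adjoint. A check that reduces the output to a scalar with a plain `sum` effectively contracts the gradient with a single all-ones seed, and any error orthogonal to that seed is invisible. Here the hypothetical error `E` is antisymmetric, so its entries sum to zero. The matrices `G_true`, `E`, and `G_buggy` are invented for illustration.

```julia
using LinearAlgebra

# Hypothetical numbers: a "true" cotangent for a 3×3 matrix input and a buggy
# one that differs from it by an antisymmetric error E (E' == -E).
G_true = [1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0]
E = [ 0.0  0.5 -0.2;
     -0.5  0.0  0.1;
      0.2 -0.1  0.0]
G_buggy = G_true + E

# A sum-based check only sees the Frobenius inner product of the gradient
# with an all-ones (symmetric) seed; an antisymmetric error is orthogonal
# to that seed, so it cancels out.
seed = ones(3, 3)
sum_check_passes = isapprox(dot(seed, G_true), dot(seed, G_buggy))

# Comparing the whole gradient entry by entry exposes the error.
full_check_passes = isapprox(G_true, G_buggy)
```

Here `sum_check_passes` is `true` while `full_check_passes` is `false`, which is why comparing the full gradient (as ChainRulesTestUtils-style tests do) is the more reliable check.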
More broadly, this is a great example of what can go wrong with Zygote's current method of gradient checks. Moving forward, ChainRules has ChainRulesTestUtils, which uses more complete finite-difference-based tests, and adjoints implemented with ChainRulesCore can use it. It is probably still worth fixing Zygote's grad-tests.
Quoting relevant bits from #366
@oxinabox
@sethaxen
@antoine-levitt