EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Autodiff for `A -> A * A'` does not give hermitian result for complex `A` #1456

Open simsurace opened 1 month ago

simsurace commented 1 month ago

This was found while writing tests for #1307, where the function below is composed with cholesky:

using Enzyme, EnzymeTestUtils, LinearAlgebra

square(A) = A * adjoint(A)

A = rand(ComplexF64, 5, 5)
ishermitian(square(A)) # true

dA = rand(ComplexF64, 5, 5)
S, dS = autodiff(Forward, square, Duplicated, Duplicated(A, dA))
S ≈ square(A) # true
ishermitian(S) # false

test_forward(square, Duplicated, (A, Duplicated)) # passes

So somehow the forward mode does not produce the same result as the function square itself, but this is not caught by EnzymeTestUtils.

Something similar happens for reverse mode:

using Enzyme, EnzymeTestUtils, LinearAlgebra

square(A) = A * adjoint(A)
square!(S, A) = (mul!(S, A, adjoint(A)); return nothing)

A = rand(ComplexF64, 5, 5)
dA = zeros(ComplexF64, 5, 5)
S = zeros(ComplexF64, 5, 5)
dS = ones(ComplexF64, 5, 5)

autodiff(Reverse, square!, Const, Duplicated(S, dS), Duplicated(A, dA))
S ≈ square(A) # true
ishermitian(S) # false

test_reverse(square!, Const, (S, Duplicated), (A, Duplicated)) # passes
wsmoses commented 1 month ago

Looks like it's Hermitian up to floating-point precision.

julia> square(A)
5×5 Matrix{ComplexF64}:
 2.70575+0.0im        2.08828-0.483085im  2.06124+0.320394im   1.98721+0.052077im   2.05039+0.0833814im
 2.08828+0.483085im   2.44817+0.0im       2.11651+1.24836im     2.0555+0.458592im   2.04032+0.926832im
 2.06124-0.320394im   2.11651-1.24836im   3.83884+0.0im        2.57611-0.0681123im  2.66486-0.319575im
 1.98721-0.052077im    2.0555-0.458592im  2.57611+0.0681123im  2.99428+0.0im        2.49473+0.129998im
 2.05039-0.0833814im  2.04032-0.926832im  2.66486+0.319575im   2.49473-0.129998im   3.13973+0.0im

julia> S
5×5 Matrix{ComplexF64}:
 2.70575+3.6097e-17im  2.08828-0.483085im     2.06124+0.320394im     1.98721+0.052077im    2.05039+0.0833814im
 2.08828+0.483085im    2.44817-6.02588e-18im  2.11651+1.24836im       2.0555+0.458592im    2.04032+0.926832im
 2.06124-0.320394im    2.11651-1.24836im      3.83884+4.38384e-18im  2.57611-0.0681123im   2.66486-0.319575im
 1.98721-0.052077im     2.0555-0.458592im     2.57611+0.0681123im    2.99428+4.6523e-17im  2.49473+0.129998im
 2.05039-0.0833814im   2.04032-0.926832im     2.66486+0.319575im     2.49473-0.129998im    3.13973+1.95764e-17im

julia> square(A)-S
5×5 Matrix{ComplexF64}:
          0.0-3.6097e-17im            0.0+0.0im                   0.0-5.55112e-17im   2.22045e-16-3.46945e-17im  -4.44089e-16+1.249e-16im
          0.0+0.0im                   0.0+6.02588e-18im           0.0+0.0im          -4.44089e-16-5.55112e-17im   4.44089e-16+0.0im
          0.0+5.55112e-17im           0.0+0.0im           4.44089e-16-4.38384e-18im           0.0-1.11022e-16im  -4.44089e-16-5.55112e-17im
  2.22045e-16+3.46945e-17im  -4.44089e-16+5.55112e-17im           0.0+1.11022e-16im   4.44089e-16-4.6523e-17im            0.0-2.77556e-17im
 -4.44089e-16-1.249e-16im     4.44089e-16+0.0im          -4.44089e-16+5.55112e-17im           0.0+2.77556e-17im           0.0-1.95764e-17im

Per the warning

 ┌ Warning: Using fallback BLAS replacements for (["zgemm_64_", "zherk_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/kqxyC/src/utils.jl:59

it is known that the implementation of zgemm will be replaced by a different implementation, which here apparently results in some numeric differences within a reasonable tolerance.
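A tolerance-aware check makes this concrete. Below is a small sketch (ishermitian_upto is a hypothetical helper, not part of LinearAlgebra) that accepts a matrix as Hermitian whenever it matches its adjoint up to roundoff, which is what the outputs above satisfy:

```julia
using LinearAlgebra

# Hypothetical helper: unlike the exact ishermitian, treat a matrix as
# Hermitian if it equals its adjoint up to a roundoff-sized tolerance.
ishermitian_upto(S; rtol = sqrt(eps(real(eltype(S))))) =
    isapprox(S, adjoint(S); rtol = rtol)

# A 2×2 matrix in the same spirit as the output above: exactly conjugate
# off-diagonals, but diagonal imaginary parts of order 1e-17.
S = [2.70575+3.6e-17im  2.08828-0.483085im;
     2.08828+0.483085im 2.44817-6.0e-18im]

ishermitian(S)       # false: exact elementwise comparison
ishermitian_upto(S)  # true: equal up to tolerance
```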

wsmoses commented 1 month ago

I actually think the behavior here is reasonable, and the test utils properly only check within a tolerance.

@simsurace is there a reason why this is problematic?

simsurace commented 1 month ago

Ah, makes sense. What is needed to avoid relying on these fallbacks? At least for forward mode, what is the reason for not just calling the same function that is being passed? I understand that there must be a fallback from which to generate LLVM in order to compute the derivative, but could the primal still use the non-fallback implementation, or would that lead to problems?

Of course one can work around this, but the composition of this function with something that relies on the intermediate result being Hermitian currently fails.
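For the record, one such workaround can be sketched as follows (assuming the downstream consumer is cholesky, as in the motivating case): re-hermitize the primal output before passing it on, either by wrapping it in Hermitian or by averaging it with its adjoint.

```julia
using LinearAlgebra

# Stand-in for a primal output that is Hermitian only up to roundoff,
# e.g. S as returned by the autodiff calls above.
S = [2.0+1e-17im  1.0-0.5im;
     1.0+0.5im    3.0-1e-17im]

H1 = Hermitian(S)          # view using the upper triangle; exactly Hermitian
H2 = (S + adjoint(S)) / 2  # symmetrized copy; also exactly Hermitian

cholesky(H1)  # succeeds, whereas cholesky(S) throws for non-Hermitian S
```

Hermitian(S) simply reinterprets one triangle and discards the roundoff-sized asymmetry, so cholesky accepts it without any extra arithmetic.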

wsmoses commented 1 month ago

So now forward mode still uses the fallback for complex inputs, but for reals it no longer does.