JuliaDiff / ChainRules.jl

forward and reverse mode automatic differentiation primitives for Julia Base + StdLibs

cholesky decomposition #414

Open alfredjmduncan opened 3 years ago

alfredjmduncan commented 3 years ago

I'm getting incorrect results from the rrule for cholesky when A <: LinearAlgebra.HermOrSym.

Passing the input matrix through Matrix fixes the issue. The mul! fix relates to this issue.

using Zygote, ChainRules, LinearAlgebra
# Example matrix
A = [2. -1. 0.; -1. 2. -1.; 0. -1. 2.]

import LinearAlgebra.mul!
LinearAlgebra.mul!(C, ::ChainRulesCore.ZeroTangent, ::Any, ::Any, b) = C *= b

# produces zeros
Zygote.jacobian(a -> cholesky(Hermitian(a)).L , A)[1]

# both of the following produce expected result
Zygote.jacobian(a -> cholesky(Matrix(Hermitian(a))).L , A)[1]
Zygote.jacobian(a -> cholesky(a).L , A)[1]

oxinabox commented 3 years ago

Huh, I really thought that definition for mul! would fix things correctly. I will have to look closely. Oh, I think it needs a dot:

LinearAlgebra.mul!(C, ::ChainRulesCore.ZeroTangent, ::Any, ::Any, b) = C .*= b

That seems likely to be less efficient for Bools, though (it might still optimize out, so it might need an if).
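
For context on why the dot matters, here is a minimal sketch in plain Julia (Base only). The function names scale_rebind! and scale_inplace! are hypothetical and not part of LinearAlgebra or ChainRules. They only illustrate that C *= b rebinds the local variable to a freshly allocated array, while C .*= b writes into the array the caller passed in, which is what a mul! method is expected to do.

```julia
# Hypothetical helpers, for illustration only (not part of any library).

# `C *= b` is shorthand for `C = C * b`: it allocates a new array and rebinds
# the local name `C`. The caller's array is left unchanged.
function scale_rebind!(C, b)
    C *= b
    return C
end

# `C .*= b` broadcasts the multiplication into the existing array,
# so the caller's array is mutated in place.
function scale_inplace!(C, b)
    C .*= b
    return C
end

X = [1.0 2.0; 3.0 4.0]
Y = copy(X)

scale_rebind!(X, 2.0)   # result is returned, but X itself is untouched
scale_inplace!(Y, 2.0)  # Y now holds the scaled values

println(X)
println(Y)
```

Under this reading, a mul! method written with *= would return the scaled result but leave the destination C as it was, which may be part of why the definition in the original post does not behave as intended.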