JuliaDiff / ChainRules.jl

forward and reverse mode automatic differentiation primitives for Julia Base + StdLibs

Wrong call on `cholesky` `rrule` #611

Closed theogf closed 2 years ago

theogf commented 2 years ago

Identified in https://github.com/JuliaStats/PDMats.jl/issues/159

When calling the cholesky rrule, there is an error in the code at https://github.com/JuliaDiff/ChainRules.jl/blob/c5dbe030af390599848830ff43a5dffc04be69e2/src/rulesets/LinearAlgebra/factorization.jl#L526

Ū = ΔC.U will be NoTangent() since the Tangent of a Cholesky object does not contain the field U. I think it should be Ū = ΔC.factors.

sethaxen commented 2 years ago

> Ū = ΔC.U will be NoTangent() since the Tangent of Cholesky object does not contain the field U.

Cholesky does not have the field U, but it does have the property U. Then there's this rrule for getproperty: https://github.com/JuliaDiff/ChainRules.jl/blob/c5dbe030af390599848830ff43a5dffc04be69e2/src/rulesets/LinearAlgebra/factorization.jl#L534-L553
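
To make the field/property distinction concrete, here is a minimal sketch that uses only the LinearAlgebra standard library. The matrix `A` is an arbitrary positive-definite example chosen for illustration, not something taken from the issue.

```julia
using LinearAlgebra

# Arbitrary symmetric positive-definite matrix, chosen only for illustration.
A = [4.0 2.0; 2.0 3.0]
C = cholesky(A)

# `U` is not a stored field of the struct...
@assert !hasfield(typeof(C), :U)
@assert hasfield(typeof(C), :factors)

# ...but it is exposed as a property, computed from `factors` by `getproperty`.
@assert hasproperty(C, :U)
@assert C.U' * C.U ≈ A
```

Because `C.U` goes through `getproperty`, a rule for that `getproperty` call is what decides which name the cotangent of `U` ends up under.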

Note that the current rules won't compose well if any downstream code accesses factors, but everything should work well if they access U or L. This should probably be improved though to look more like the rules for lu:

https://github.com/JuliaDiff/ChainRules.jl/blob/c5dbe030af390599848830ff43a5dffc04be69e2/src/rulesets/LinearAlgebra/factorization.jl#L80-L169

theogf commented 2 years ago

So this assumes that ΔC here is a Cholesky object, right? https://github.com/JuliaDiff/ChainRules.jl/blob/c5dbe030af390599848830ff43a5dffc04be69e2/src/rulesets/LinearAlgebra/factorization.jl#L524 That is not possible, since it is obviously a tangent: https://github.com/JuliaDiff/ChainRules.jl/blob/c5dbe030af390599848830ff43a5dffc04be69e2/src/rulesets/LinearAlgebra/factorization.jl#L514

sethaxen commented 2 years ago

> So this assumes that ΔC here is a Cholesky object right?

I don't see anywhere it assumes that. The preceding line has the type annotation ΔC::Tangent. The (co)tangent for Cholesky will not be a Cholesky.

theogf commented 2 years ago

Right, so ΔC is a Tangent! So ΔC.U will call getproperty here: https://github.com/JuliaDiff/ChainRulesCore.jl/blob/2d75b4be102bb41ba3ac6df6dec8bb9617b20f0f/src/tangent_types/tangent.jl#L104 And since hasfield(NamedTuple, :U) is false, it will return NoTangent()...

sethaxen commented 2 years ago

T refers to the type of the backing, which has a field for each keyword passed to the Tangent constructor, and those keywords can be anything. For example:

julia> using ChainRulesCore

julia> struct Cholesky
           factors
       end

julia> F = Cholesky(randn(5, 5));

julia> ΔF = Tangent{typeof(F)}(U=randn(5,5))
Tangent{Cholesky}(U = [-0.7059268385030435 -0.06675693167812352 … -0.2413203958757456 0.03638639429591188; -0.47788794725474393 -1.2764613865307353 … 1.2012293162255412 0.15961345415826841; … ; 0.6552274641328225 0.24813748236499192 … 1.3140512188252975 0.22398055650819915; 2.2445029073821035 -0.5210976005765643 … 1.6755604234385513 0.4181320724677388],)

julia> ΔF.factors
ZeroTangent()

julia> ΔF.U
5×5 Matrix{Float64}:
 -0.705927  -0.0667569  -1.42752    -0.24132  0.0363864
 -0.477888  -1.27646    -0.652307    1.20123  0.159613
 -0.449378  -0.696023   -1.21753    -1.20449  0.23459
  0.655227   0.248137    0.0819825   1.31405  0.223981
  2.2445    -0.521098   -0.145292    1.67556  0.418132

Note that if this wasn't the current behavior, the rule for cholesky would always produce a ZeroTangent(), which it doesn't. Is there a particular surprising behavior/error you encountered that led to this issue?

devmotion commented 2 years ago

I think that https://github.com/JuliaDiff/ChainRules.jl/blob/c5dbe030af390599848830ff43a5dffc04be69e2/src/rulesets/LinearAlgebra/factorization.jl#L526 causes some of the test errors in https://github.com/FluxML/Zygote.jl/pull/1114 (e.g., in https://github.com/FluxML/Zygote.jl/runs/6021203148?check_suite_focus=true#step:6:192).

Due to Ū being a ZeroTangent, the call in https://github.com/JuliaDiff/ChainRules.jl/blob/02a8172ad7a4d3d54dc7c680c3c070f3071f3f14/src/rulesets/LinearAlgebra/factorization.jl#L528 errors. It might be useful to define 3- and 5-argument mul! methods with AbstractZeroTangent arguments (even though I assume one might run into many method ambiguity issues), but I don't think that (all) the examples in the Zygote tests should use a ZeroTangent there, so the primary issue seems to be that Ū is wrong.

sethaxen commented 2 years ago

> so the primary issue seems to be that Ū is wrong.

Can you clarify what you mean here?

devmotion commented 2 years ago

Ū is a ZeroTangent, but I think it shouldn't be. The problem is exactly the one @theogf described above: the tangent has no field named U, only one named, e.g., L or factors (the main difference in the Zygote implementation is actually that it accesses factors instead).

It seems this problem could occur with any hardcoded field access here, so maybe some custom _get_U(tangent) would be needed - if field U exists, it's returned, otherwise it's based on factors, and if that does not exist we transpose L (if it doesn't exist either that should automatically return a ZeroTangent).
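
The fallback order described above could be sketched roughly as follows. This is only an illustration: `_get_U` is the hypothetical helper named in this comment and is not part of ChainRules, the sketch assumes ChainRulesCore is installed, and it assumes that `factors` stores the upper factor (uplo == 'U'). It relies on a `Tangent` returning a zero tangent for names that are not in its backing.

```julia
using ChainRulesCore   # assumed to be installed; not part of the standard library
using LinearAlgebra

# Hypothetical helper (not part of ChainRules): look for a cotangent of U under
# the names a Cholesky tangent might carry, in the order U, factors, L.
function _get_U(ΔC::Tangent)
    Ū = ΔC.U
    Ū isa AbstractZero || return Ū
    F̄ = ΔC.factors
    # Assumption: `factors` holds U in its upper triangle (uplo == 'U').
    F̄ isa AbstractZero || return triu(F̄)
    L̄ = ΔC.L
    # Since U = L', the cotangent of U is the adjoint of the cotangent of L.
    L̄ isa AbstractZero || return L̄'
    return ZeroTangent()
end

# Usage: a tangent that only carries a cotangent for L.
C = cholesky([4.0 2.0; 2.0 3.0])
ΔC = Tangent{typeof(C)}(L = [1.0 0.0; 2.0 3.0])
Ū = _get_U(ΔC)
```

A helper for the uplo == 'L' case would also need access to the primal Cholesky object, which this single-argument signature does not provide.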

sethaxen commented 2 years ago

> Ū is a ZeroTangent, but I think it shouldn't be. The problem is exactly the one @theogf described above: the tangent has no field named U, only one named, e.g., L or factors (the main difference in the Zygote implementation is actually that it accesses factors instead).

It would be helpful if we had an MWE to structure this conversation around. Is there one you can construct from the Zygote failures you mentioned?

> It seems this problem could occur with any hardcoded field access here, so maybe some custom _get_U(tangent) would be needed - if field U exists, it's returned, otherwise it's based on factors, and if that does not exist we transpose L (if it doesn't exist either that should automatically return a ZeroTangent).

Perhaps, but it might be cleaner to rewrite the rrule for getproperty to accumulate cotangents of factors and then for the rrule of cholesky to work with the cotangent of factors like what we do with lu:

https://github.com/JuliaDiff/ChainRules.jl/blob/c5dbe030af390599848830ff43a5dffc04be69e2/src/rulesets/LinearAlgebra/factorization.jl#L80-L169
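
As a rough illustration of that idea (not the actual ChainRules implementation), the sketch below maps the cotangent of the `U` or `L` property onto a cotangent for the stored `factors` field. The name `factors_cotangent` is hypothetical, and the sketch assumes ChainRulesCore is installed.

```julia
using ChainRulesCore   # assumed to be installed; not part of the standard library
using LinearAlgebra

# Hypothetical helper (not the ChainRules implementation): turn the cotangent Ȳ
# of the property `name` (:U or :L) into a Tangent that carries `factors` only.
function factors_cotangent(C::Cholesky, name::Symbol, Ȳ::AbstractMatrix)
    # First express the cotangent with respect to U (L = U', so Ū = L̄').
    Ū = if name === :U
        Ȳ
    elseif name === :L
        Ȳ'
    else
        throw(ArgumentError("unsupported property $name"))
    end
    # `factors` holds U in its upper triangle when uplo == 'U', and L = U' in
    # its lower triangle otherwise; the other triangle is not read by U or L.
    Δfactors = C.uplo === 'U' ? triu(Ū) : tril(Ū')
    return Tangent{typeof(C)}(; factors = Δfactors)
end

# Usage: cotangents arriving through U and through L land in the same field,
# so they can be accumulated with `+`.
C = cholesky([4.0 2.0; 2.0 3.0])
ΔC_from_L = factors_cotangent(C, :L, [1.0 0.0; 2.0 3.0])
ΔC_from_U = factors_cotangent(C, :U, [1.0 1.0; 5.0 1.0])
ΔC_total = ΔC_from_L + ΔC_from_U
```

With every property access reporting its cotangent under `factors`, a pullback for cholesky would only need to read `ΔC.factors`, whichever property the downstream code used.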

devmotion commented 2 years ago

> Perhaps, but it might be cleaner to rewrite the rrule for getproperty to accumulate cotangents of factors and then for the rrule of cholesky to work with the cotangent of factors like what we do with lu:

Yes, that sounds simpler.

devmotion commented 2 years ago

> Is there one you can construct from the Zygote failures you mentioned?

These Zygote failures happen once one removes the Zygote adjoint for cholesky. Then basically even the simplest examples start to fail. E.g., the linked issue above is triggered by https://github.com/FluxML/Zygote.jl/blob/af8aee4d8acc94bdfa8b9a1c7e16ef0b6a3df32e/test/gradcheck.jl#L650.

sethaxen commented 2 years ago

> > Is there one you can construct from the Zygote failures you mentioned?
>
> These Zygote failures happen once one removes the Zygote adjoint for cholesky. Then basically even the simplest examples start to fail. E.g., the linked issue above is triggered by https://github.com/FluxML/Zygote.jl/blob/af8aee4d8acc94bdfa8b9a1c7e16ef0b6a3df32e/test/gradcheck.jl#L650.

The main issue there is that the function \ depends on the adjoints for cholesky and \(::Cholesky, ::AbstractVecOrMat), both of which are composable in Zygote. The linked PR deletes only Zygote's adjoint for cholesky, and Zygote's adjoint for \ does not compose well with the ChainRules rrule for cholesky, since it is written in terms of factors. We should actually have the \ rrule here in ChainRules.

I'll open up PRs to both use factors in the cholesky-related rrules and to migrate the rrule for \ to here.