Closed theogf closed 2 years ago
`Ū = ΔC.U` will be `NoTangent()` since the `Tangent` of the `Cholesky` object does not contain the field `U`.
`Cholesky` does not have the field `U`, but it does have the property `U`. Then there's this `rrule` for `getproperty`:
https://github.com/JuliaDiff/ChainRules.jl/blob/c5dbe030af390599848830ff43a5dffc04be69e2/src/rulesets/LinearAlgebra/factorization.jl#L534-L553
Note that the current rules won't compose well if any downstream code accesses `factors`, but everything should work well if it accesses `U` or `L`. This should probably be improved, though, to look more like the rules for `lu`.
So this assumes that `ΔC` here is a `Cholesky` object, right?
https://github.com/JuliaDiff/ChainRules.jl/blob/c5dbe030af390599848830ff43a5dffc04be69e2/src/rulesets/LinearAlgebra/factorization.jl#L524
Which is not possible, since it is obviously a tangent:
https://github.com/JuliaDiff/ChainRules.jl/blob/c5dbe030af390599848830ff43a5dffc04be69e2/src/rulesets/LinearAlgebra/factorization.jl#L514
> So this assumes that `ΔC` here is a `Cholesky` object, right?
I don't see anywhere it assumes that. The preceding line has the type annotation `ΔC::Tangent`. The (co)tangent for `Cholesky` will not be a `Cholesky`.
Right, so `ΔC` is a `Tangent`! So `ΔC.U` will call `getproperty` here:
https://github.com/JuliaDiff/ChainRulesCore.jl/blob/2d75b4be102bb41ba3ac6df6dec8bb9617b20f0f/src/tangent_types/tangent.jl#L104
And since `hasfield(NamedTuple, :U)` is false, it will return `NoTangent()`...
`T` refers to the type of the backing, which will have fields for whatever keywords were passed to the `Tangent` constructor, which can be anything. E.g.:
```julia
julia> using ChainRulesCore

julia> struct Cholesky
           factors
       end

julia> F = Cholesky(randn(5, 5));

julia> ΔF = Tangent{typeof(F)}(U=randn(5,5))
Tangent{Cholesky}(U = [-0.7059268385030435 -0.06675693167812352 … -0.2413203958757456 0.03638639429591188; -0.47788794725474393 -1.2764613865307353 … 1.2012293162255412 0.15961345415826841; … ; 0.6552274641328225 0.24813748236499192 … 1.3140512188252975 0.22398055650819915; 2.2445029073821035 -0.5210976005765643 … 1.6755604234385513 0.4181320724677388],)

julia> ΔF.factors
ZeroTangent()

julia> ΔF.U
5×5 Matrix{Float64}:
 -0.705927  -0.0667569  -1.42752    -0.24132   0.0363864
 -0.477888  -1.27646    -0.652307    1.20123   0.159613
 -0.449378  -0.696023   -1.21753    -1.20449   0.23459
  0.655227   0.248137    0.0819825   1.31405   0.223981
  2.2445    -0.521098   -0.145292    1.67556   0.418132
```
Note that if this wasn't the current behavior, the rule for `cholesky` would always produce a `ZeroTangent()`, which it doesn't. Is there a particular surprising behavior/error you encountered that led to this issue?
I think that https://github.com/JuliaDiff/ChainRules.jl/blob/c5dbe030af390599848830ff43a5dffc04be69e2/src/rulesets/LinearAlgebra/factorization.jl#L526 causes some of the test errors in https://github.com/FluxML/Zygote.jl/pull/1114 (e.g., in https://github.com/FluxML/Zygote.jl/runs/6021203148?check_suite_focus=true#step:6:192).
Due to `Ū` being a `ZeroTangent`, the call in https://github.com/JuliaDiff/ChainRules.jl/blob/02a8172ad7a4d3d54dc7c680c3c070f3071f3f14/src/rulesets/LinearAlgebra/factorization.jl#L528 errors. Possibly it could be useful to define 3- and 5-argument `mul!` with `AbstractZeroTangent` arguments (even though I assume one might run into many method ambiguity issues), but I don't think that (all) the examples in the Zygote tests should use a `ZeroTangent` there, so the primary issue seems to be that `Ū` is wrong.
> so the primary issue seems to be that `Ū` is wrong.

Can you clarify what you mean here?
It is a `ZeroTangent` but I think it shouldn't be. The problem is exactly the one @theogf described above: the tangent has no field named `U` but instead has, e.g., one named `L` or `factors` (the main difference in the Zygote implementations is actually that it accesses `factors` instead).
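The failure mode described here can be reproduced in a few lines. This is a minimal sketch, assuming the cotangent that reaches the `cholesky` pullback was built from an access to `L` rather than `U`; the matrix values and the names `ΔU` and `ΔL` are illustrative only.

```julia
using ChainRulesCore
using LinearAlgebra

C = cholesky([4.0 2.0; 2.0 3.0])

# Assumed scenario: downstream code accessed `C.L`, so the structural
# tangent only carries an `L` entry in its backing NamedTuple.
ΔC = Tangent{typeof(C)}(L = ones(2, 2))

ΔU = ΔC.U  # no `U` entry in the backing, so a zero tangent comes back
ΔL = ΔC.L  # the entry that was actually stored

println(ΔU isa AbstractZero)
println(ΔL)
```

A pullback that hardcodes `ΔC.U` would therefore silently drop the contribution stored under `L`.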
It seems this problem could occur with any hardcoded field access here, so maybe some custom `_get_U(tangent)` would be needed: if field `U` exists, it's returned; otherwise the result is based on `factors`; and if that does not exist we transpose `L` (if that doesn't exist either, this should automatically return a `ZeroTangent`).
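The lookup order proposed above could be sketched roughly as follows. This is a hypothetical helper, not code from ChainRules: the `uplo` argument and the way a `factors` cotangent is projected onto the upper triangle are assumptions made for illustration.

```julia
using ChainRulesCore
using LinearAlgebra

# Hypothetical helper: recover a cotangent for `U` from whichever entry the
# structural tangent carries, trying `U`, then `factors`, then `L`.
# Assumption: `uplo` is the `uplo` field of the primal Cholesky object.
function _get_U(ΔC::Tangent, uplo::Char)
    ΔU = ΔC.U
    ΔU isa AbstractZero || return ΔU

    Δfactors = ΔC.factors
    if !(Δfactors isa AbstractZero)
        # Assumption: only the triangle that stores the factor matters.
        return uplo == 'U' ? triu(Δfactors) : triu(copy(Δfactors'))
    end

    ΔL = ΔC.L
    ΔL isa AbstractZero || return copy(ΔL')

    return ZeroTangent()
end

C = cholesky([4.0 2.0; 2.0 3.0])
T = typeof(C)
Δ = [1.0 2.0; 3.0 4.0]

println(_get_U(Tangent{T}(U = Δ), C.uplo))
println(_get_U(Tangent{T}(factors = Δ), C.uplo))
println(_get_U(Tangent{T}(L = Δ), C.uplo))
```

Note that this sketch returns the first entry it finds and does not accumulate when several entries are present at once, which is one reason a rule written purely in terms of `factors` may be cleaner.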
> It is a `ZeroTangent` but I think it shouldn't. The problem is exactly the one @theogf described above: there's no field of name `U` in the tangent but e.g. of name `L` or `factors` (the main difference in the Zygote implementations is actually that it accesses `factors` instead).

It would be helpful if we had an MWE to structure this conversation around. Is there one you can construct from the Zygote failures you mentioned?
> It seems this problem could occur with any hardcoded field access here, so maybe some custom `_get_U(tangent)` would be needed - if field `U` exists, it's returned, otherwise it's based on `factors`, and if that does not exist we transpose `L` (if it doesn't exist either that should automatically return a `ZeroTangent`).
Perhaps, but it might be cleaner to rewrite the `rrule` for `getproperty` to accumulate cotangents of `factors`, and then for the `rrule` of `cholesky` to work with the cotangent of `factors`, like what we do with `lu`.
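To make the idea of accumulating into `factors` concrete, here is a rough sketch of how a cotangent of the `U` or `L` property could be mapped onto a cotangent of `factors`. The helper name `factors_cotangent` is hypothetical and this is not the ChainRules implementation (nor the `lu` rules); it only assumes that `factors` stores the factor in the triangle indicated by `uplo`.

```julia
using LinearAlgebra

# Hypothetical helper: given the cotangent Ȳ of property `name` (:U or :L)
# of a Cholesky object with the given `uplo`, return a cotangent for `factors`.
function factors_cotangent(uplo::Char, name::Symbol, Ȳ::AbstractMatrix)
    if name === :U
        # `U` is either stored directly (uplo == 'U') or is the adjoint of `L`.
        return uplo == 'U' ? triu(Ȳ) : tril(copy(Ȳ'))
    elseif name === :L
        # `L` is either stored directly (uplo == 'L') or is the adjoint of `U`.
        return uplo == 'L' ? tril(Ȳ) : triu(copy(Ȳ'))
    else
        throw(ArgumentError("unsupported property: $name"))
    end
end

Ȳ = [1.0 2.0; 3.0 4.0]
println(factors_cotangent('U', :U, Ȳ))
println(factors_cotangent('U', :L, Ȳ))
```

With every property access funneled into a single `factors` entry, the `cholesky` pullback would only need to read `ΔC.factors`, regardless of whether downstream code used `U`, `L`, or `factors`.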
> Perhaps, but it might be cleaner to rewrite the `rrule` for `getproperty` to accumulate cotangents of `factors` and then for the `rrule` of `cholesky` to work with the cotangent of `factors` like what we do with `lu`:
Yes, that sounds simpler.
> Is there one you can construct from the Zygote failures you mentioned?
These Zygote failures happen once one removes the Zygote adjoint for `cholesky`. Then basically even the simplest examples start to fail. E.g., the linked issue above is triggered by https://github.com/FluxML/Zygote.jl/blob/af8aee4d8acc94bdfa8b9a1c7e16ef0b6a3df32e/test/gradcheck.jl#L650.
> > Is there one you can construct from the Zygote failures you mentioned?
>
> These Zygote failures happen once one removes the Zygote adjoint for `cholesky`. Then basically even the simplest examples start to fail. E.g., the linked issue above is triggered by https://github.com/FluxML/Zygote.jl/blob/af8aee4d8acc94bdfa8b9a1c7e16ef0b6a3df32e/test/gradcheck.jl#L650.
The main issue there is that the function `\` depends on the adjoints for `cholesky` and `\(::Cholesky, ::AbstractVecOrMat)`, both of which are composable in Zygote. The linked PR deletes only Zygote's adjoint for `cholesky`, and its adjoint for `\` does not compose well with ChainRules's, since it's written in terms of `factors`. We should actually have the `\` `rrule` here in ChainRules.
I'll open up PRs to both use `factors` in the `cholesky`-related `rrule`s and to migrate the `rrule` for `\` to here.
Identified in https://github.com/JuliaStats/PDMats.jl/issues/159
When calling the `cholesky` `rrule`, there is an error in the code at https://github.com/JuliaDiff/ChainRules.jl/blob/c5dbe030af390599848830ff43a5dffc04be69e2/src/rulesets/LinearAlgebra/factorization.jl#L526
`Ū = ΔC.U` will be `NoTangent()` since the `Tangent` of the `Cholesky` object does not contain the field `U`. I think it should be `Ū = ΔC.factors`.