I would argue that rules like `rrule(f, X::Diagonal)` are ambiguous, and we need to add an extra argument to the `rrule` interface to resolve the ambiguity. For example, for `f(X::Diagonal) = A*X` where `A::Matrix`, both
```julia
# Diagonal-as-matrix rule
rrule(::typeof(f), X::Diagonal) = A*X, Ȳ->A'*Ȳ
```
and
```julia
# Diagonal-as-diagonal rule
rrule(::typeof(f), X::Diagonal) = A*X, Ȳ->Diagonal([dot(A[:,i], Ȳ[:,i]) for i = 1:size(A,2)])
```
are reasonable definitions, but they correspond to two different `rrule`s, and hence issues arise because the current ChainRules interface does not allow us to distinguish the two. I believe a reasonable way to resolve this would be to introduce an additional argument, as in
```julia
# Diagonal-as-matrix rule
rrule(::typeof(f), ::Type{Tuple{Matrix}}, X::Diagonal) = A*X, Ȳ->A'*Ȳ

# Diagonal-as-diagonal rule
rrule(::typeof(f), ::Type{Tuple{Diagonal}}, X::Diagonal) = A*X, Ȳ->Diagonal([dot(A[:,i], Ȳ[:,i]) for i = 1:size(A,2)])
```
(The exact form of the `Type` argument is up for discussion. For example, it is possible that replacing `Matrix` with `AbstractMatrix` is better.)
To demonstrate how such `rrule`s would be chained, let us consider a function `fg(x::AbstractMatrix) = f(g(x)::AbstractMatrix)`. In such a case, the user would have to replace `rrule(fg, x::Diagonal)` with either `rrule(fg, Tuple{Matrix}, x::Diagonal)` or `rrule(fg, Tuple{Diagonal}, x::Diagonal)` to specify which derivative of `fg` they want, but then it is up to the AD tool to figure out which `rrule`s to call for `f` and `g`. For `g`, the choice is fairly obvious: `rrule(fg, MatOrDiag, x)` must call `rrule(g, MatOrDiag, x)` to ensure that the final adjoint matches the user's expectation. For `f`, I believe the answer can be worked out as follows: if `g(x::MatOrDiag) isa Diagonal`, then call `rrule(f, Diagonal, g(x))`, and if `g(x::MatOrDiag) isa Matrix`, then call `rrule(f, Matrix, g(x))`. The key point here is that it is `typeof(g(::MatOrDiag))`, not `typeof(g(::typeof(x)))`, which specifies which `rrule` to call for `f`. Of course, there is also the possibility that `g(x)` is not type-stable, in which case it is not clear to me what should happen. However, I believe we can postpone thinking about this until it is more clear that this is truly the direction we want to pursue.
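A rough sketch of how an AD tool might mechanically compose these (hypothetical code against the proposed interface, not an existing API; cotangents for the `Type` argument are omitted for brevity):

```julia
# Hypothetical chaining for fg(x) = f(g(x)) under the proposed interface.
# T is the differential representation the user requested for x.
function rrule(::typeof(fg), ::Type{Tuple{T}}, x::AbstractMatrix) where {T}
    y, pullback_g = rrule(g, Tuple{T}, x)          # honour the user's choice for x
    z, pullback_f = rrule(f, Tuple{typeof(y)}, y)  # keyed on typeof(g(x)), not typeof(x)
    function pullback_fg(z̄)
        _, ȳ = pullback_f(z̄)
        _, x̄ = pullback_g(ȳ)
        return (NO_FIELDS, x̄)
    end
    return z, pullback_fg
end
```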
Hmm interesting proposal. If we conclude that they're both reasonable definitions, then that's certainly something that we should consider.
Another way of looking at this problem (that I forgot to mention before) is to consider what would happen in an AD system where you hadn't defined an `rrule` for any given operation involving a `Diagonal`, but you had defined the rules necessary to backprop through it (e.g. addition and multiplication of scalars in the matrix-multiplication case). If you applied `Zygote` to `D * X` then the cotangent you'll obtain for `D` will be a `Composite{typeof(D)}(; diag=some_vector)`, rather than a `Matrix`.
So the argument goes as follows:

1. Differentiating through the code of the operation without the `rrule` emits the correct result (slowly).
2. The `rrule` is only a performance optimisation, and shouldn't change the output of AD.
3. So if the `rrule` outputs a `Matrix` rather than a `Composite`, it has produced the wrong answer.

As regards whether 1 holds or not -- you certainly need to have already defined the `rrule`s necessary to be able to differentiate through the code of the operation for which you're defining the rule. For the code that we're talking about I would argue that this likely holds -- e.g. matrix-matrix multiplication just requires `getindex`, `setindex` (Zygote doesn't currently implement this, but I believe it's well understood what it would implement), `Float64 + Float64`, and `Float64 * Float64`.
You could think of this argument as being one about consistency. It leans on the idea that there is a small set of basic rules, on whose behaviour everyone can agree, and that all other behaviour follows from there.
I very much agree that it should be possible to split ChainRules into a few rules which define the behaviour and many other rules which provide performance improvements but which could be omitted without breaking any code. I doubt that you can do AD in a sensible way without this property, because it would mean the interface that a downstream programmer works with would depend on what optimisations are currently implemented.
I believe the core, API-defining rules are `frule`s and `rrule`s for functions of `Real` arguments, and `frule`s and `rrule`s for constructors, `get`/`setindex` and `get`/`setproperty`. If we decide not to go with the `rrule(f, Type{TX...}, x...)` interface mentioned above, then I think we have no choice but to follow the diagonal-as-diagonal version mentioned above, i.e.
```julia
# Diagonal-as-diagonal rule
rrule(::typeof(f), X::Diagonal) = A*X, (Ȳ::Matrix)->Diagonal([dot(A[:,i], Ȳ[:,i]) for i = 1:size(A,2)])
rrule(::Type{Diagonal}, d) = Diagonal(d), (dD::Diagonal)->diag(dD)
rrule(::typeof(getindex), D::Diagonal, i, j) = D[i,j], (dx::Number)->Diagonal([i == j == k ? dx : zero(dx) for k = 1:size(D,1)])
```
If we tried to follow the diagonal-as-matrix rule, then we would get into trouble defining the `rrule` for the constructor:
```julia
# Diagonal-as-matrix rule
rrule(::typeof(f), ::Type{Tuple{Matrix}}, X::Diagonal) = A*X, Ȳ->A'*Ȳ
rrule(::Type{Diagonal}, d) = Diagonal(d), (dD::Matrix) -> # The output has to be a Vector, so what could we possibly do here?
rrule(::typeof(getindex), D::Diagonal, i, j) = D[i,j], (dx::Number)->[ii == i && jj == j ? dx : zero(dx) for ii = 1:size(D,1), jj = 1:size(D,2)]
```
> I very much agree that it should be possible to split ChainRules into a few rules which define the behaviour and many other rules which provide performance improvements but which could be omitted without breaking any code
This is undecidable without knowing which AD you are talking about. E.g. Tracker, Zygote, Nabla all have different limitations about what they can AD through without custom rules, such as Zygote's inability to handle mutation. Thus for each AD the "instruction set" (as I have been calling the list of things that they must have rules for) differs.

Further, for many ADs, providing rules only for `Real` arguments will work, but be so slow that it might as well not.
I guess a more accurate formulation of the point I was trying to make is that we must be able to guarantee that the many rules we might define for a particular type are all compatible with one another, and one way (perhaps the only way?) to do that is to define the core rules and then require that any other rule must have the same signature as if some hypothetical AD tool had created the rule for us. I understand that different concrete AD tools have different capabilities, but I assume that when their capabilities overlap they agree on what should happen, and so I further assume that an ideal AD tool which can create all rules from the minimal core set is well defined.
Also, I am of course not advocating that ChainRules should not provide optimisation rules, just that these rules must be consistent with what we would get if they weren't there.
fair enough. That is a decent way to put it.
One problem is natural differential types, like most structured sparse matrices, or `Period`s for my pet `DateTime` examples. As a rule, an AD system can only produce structural differential types (i.e. `Composite`), but actually you really want to write rules using the natural differential types.
Have you written this example down anywhere @oxinabox ? Would be good to gather these ideas here. In particular I would really like a crystal clear argument of why we should prefer to work with natural differentials.
The `DateTime` example is discussed here.
> As a rule, an AD system can only produce structural differential types (i.e. `Composite`), but actually you really want to write rules using the natural differential types
I believe this is a question of convention. We could either require that all types use `Composite` as their differential type, or we could allow each type to choose a suitable differential type but then require that they must be consistent about this. This does not break the "everything must be AD-able" rule, because it is up to us whether we interpret constructors and access functions as part of the core rule set to be defined by humans or something filled in by AD.

Also, note that "differential type" in the above is meant in a duck-typing sense. It is okay for a type to choose `Vector` as its differential type and then return a `SparseVector` as a differential, since the two types have compatible interfaces. But what would not be okay is for a type to sometimes use `Vector` and sometimes `Composite` as its differential, because then we would have to write all rules such that they can handle both `Vector` and `Composite` as input adjoints.
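As a concrete (and entirely illustrative) instance of that duck-typing point, a pullback whose differential type is "`Vector`-like" can return a `SparseVector` without breaking downstream rules, because the two share the `AbstractVector` interface:

```julia
using SparseArrays

# Illustrative pullback for getindex(x, i). The differential for x is
# "Vector-like"; a one-hot SparseVector is an acceptable stand-in for a
# dense one-hot Vector, since consumers only use the AbstractVector API.
function getindex_pullback(x::Vector, i::Int)
    pullback(dy) = sparsevec([i], [dy], length(x))
    return x[i], pullback
end

y, pb = getindex_pullback([1.0, 2.0, 3.0], 2)
dx = pb(1.0)   # 3-element SparseVector; duck-types as the Vector differential
```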
I keep getting sidetracked on writing this perspective up (it's long), so I'll just give the short version and apologize if it's unclear or inaccurate. I agree that ideally the right thing to do would be to write rules that do exactly what AD would have done without the rule, just more efficiently. Right now I really don't like the idea of only defining rules on functions of concrete types though. I implement a lot of custom array types, as do others, and I rely on rules defined on abstract arrays. Without them, e.g. using Zygote, I'm saddled with the (current) terrible performance of the `getindex` rules and iteration, and if I want better performance, I'd then need to reimplement a ton of LinearAlgebra rules that will be more or less identical to those for every other array type.
Treating abstract arrays as embedded in the `Array`s, it's safe to pull back dense array adjoints, and as long as 1) the initial adjoint was valid, 2) the `rrule` for the constructor is written as the composition of the constructor with the embedding into the `Array`s (e.g. see the `rrule` for `Symmetric`), and 3) the array is only ever treated as an array in the primal functions (so no accessing `Symmetric(A).data`, unless you add a custom rule to handle that like for the constructor), then everything should just work out.
But as @willtebbutt points out, you lose the time complexity of the primal function in the pullback. One way to partially get this back would be to "project" the output of each pullback to the predetermined differential type for its primal with a utility function. All rules with abstract types would do this.
So here's an example for `*`:
```julia
function rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
    function times_pullback(Ȳ)
        ∂A = @thunk(project_cotangent(*, A, Ȳ * B'))
        ∂B = @thunk(project_cotangent(*, B, A' * Ȳ))
        return (NO_FIELDS, ∂A, ∂B)
    end
    return A * B, times_pullback
end

# default ignores the function argument
project_cotangent(f, x, ∂x) = project_cotangent(x, ∂x)
# default is a no-op
project_cotangent(x, ∂x) = ∂x

# some possible projections
project_cotangent(x::Array, ∂x) = Array(∂x)
project_cotangent(x::Array, ∂x::Array) = convert(typeof(x), ∂x)
project_cotangent(x::Diagonal, ∂x) = Diagonal(∂x)
project_cotangent(x::LowerTriangular, ∂x) = LowerTriangular(∂x)
project_cotangent(x::UpperTriangular, ∂x) = UpperTriangular(∂x)
project_cotangent(x::Adjoint, ∂x::Adjoint) = ∂x
project_cotangent(x::Adjoint, ∂x) = Adjoint(project_cotangent(x.parent, adjoint(∂x)))
project_cotangent(x::Transpose, ∂x::Transpose) = ∂x
project_cotangent(x::Transpose, ∂x) = Transpose(project_cotangent(x.parent, transpose(∂x)))
project_cotangent(x::Symmetric, ∂x) = Symmetric(project_cotangent(x.data, symmetrize(∂x)), Symbol(x.uplo))
project_cotangent(x::Hermitian, ∂x) = Hermitian(project_cotangent(x.data, hermitrize(∂x)), Symbol(x.uplo))

symmetrize(x) = (x .+ transpose(x)) ./ 2
hermitrize(x) = (x .+ x') ./ 2
```
I don't think this guarantees the right time complexity, but by preserving type information that can then be used for dispatch internally within this and subsequent pullbacks, it will likely be more efficient than the pullback using dense arrays.
The downside of this approach is that if someone implements a new array type but doesn't define a rule for the constructor, then there's a good chance they won't be able to use AD with the array.
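For concreteness, here is how the projection plays out on a small example under the definitions above (values invented):

```julia
using LinearAlgebra

A = Diagonal([1.0, 2.0])
B = [1.0 2.0; 3.0 4.0]
Ȳ = ones(2, 2)

∂A_dense = Ȳ * B'                    # dense Matrix: O(N^2) storage for an O(N) primal
∂A = project_cotangent(A, ∂A_dense)  # Diagonal, matching the primal's structure
```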
I agree with you that we're going to have to make some compromises in the short- to medium-term here to not completely break Zygote. There are probably a few options, and I like the one that you suggest @sethaxen, particularly for matrices that are (often) otherwise dense-with-constraints, e.g. the `-Triangular`s and `Symmetric` etc.
One thing that might be worth working out is how custom rules for cases where the asymptotic complexity would otherwise take a hit (e.g. `Diagonal`s) and the proposed approach work together. Possibly they just do, or maybe it requires some more thought / additional methods? So if you return a `Diagonal` matrix to represent the cotangent of an `UpperTriangular{T, Diagonal{T}}`, does the "right" thing happen, so that you don't take a complexity hit?
> > but actually you really want to write rules using the natural differential types
>
> Have you written this example down anywhere @oxinabox ? Would be good to gather these ideas here. In particular I would really like a crystal clear argument of why we should prefer to work with natural differentials.
At least some of it should be written down in: https://www.juliadiff.org/ChainRulesCore.jl/dev/design/many_differentials.html

The best example I have for why one wants to work with natural differentials, and not always with structural ones, is matrix factorizations. These use `Composite` with properties corresponding to the properties of the factorization, not to the fields of the factorization object. (I think in the design docs I was calling this semi-structural?) No one wants to deal with the fields of the matrix factorizations, because they are not nice to work with. (Probably getting the last bit of performance out would want to do that, but that's another question.)

For `Cholesky`: you want to work with the properties `U` or `L` (https://github.com/JuliaDiff/ChainRules.jl/blob/d3cd83e5d202475fbc18de8fadea69fd9042f66b/src/rulesets/LinearAlgebra/factorization.jl#L87-L100). I don't think you want to deal with the true field `factors` (https://github.com/JuliaLang/julia/blob/110765a87af68120f2f9f4aa0bbc4054db491359/stdlib/LinearAlgebra/src/cholesky.jl#L119-). Though maybe I am wrong in this case, since the relationship of `.factors` to `.U` and `.L` is simple (https://github.com/JuliaLang/julia/blob/110765a87af68120f2f9f4aa0bbc4054db491359/stdlib/LinearAlgebra/src/cholesky.jl#L411-L423), and it would save the effort of having to define custom addition.

A better example is the `BunchKaufman` factorization. The relationship between its fields and its properties is complicated (https://github.com/JuliaLang/julia/blob/master/stdlib/LinearAlgebra/src/bunchkaufman.jl#L285-L321), and here I don't want to deal with that one at all: neither to define addition, nor to define the frules / rrules. But if I must, then I would rather do it separately, so I can focus on calculus in the rule definition.
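To illustrate the semi-structural idea (hypothetical tangent values; `Composite` as in the ChainRulesCore of the time):

```julia
using LinearAlgebra, ChainRulesCore

C = cholesky([4.0 2.0; 2.0 3.0])

# Semi-structural: keyed by the property people actually use...
ΔC = Composite{typeof(C)}(U = UpperTriangular([0.1 0.2; 0.0 0.3]))

# ...rather than by the raw field layout, which would look like
# Composite{typeof(C)}(factors = ..., uplo = ..., info = ...).
```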
Hmm, well I think I'm probably okay with using differentials that aren't the `Composite` ones provided that they're guaranteed not to kill performance, which I think holds in this case. I think this is probably the same reason that I think it's probably fine to represent the (co)tangent of a `Matrix{Float64}` with a `Fill{Float64}`, but not the other way around.
Related is https://github.com/JuliaDiff/ChainRulesCore.jl/issues/176. E.g. if `Diagonal` is an okay differential for `LowerTriangular`, but `Matrix` is not, then it follows that `Int` is an okay differential for `Float64`, but `ComplexF64` is not.
Ohhh, so potentially it's generally the case that (if we adopt a strict convention) it's okay for a "bigger" type to have a differential represented by a "smaller" type, but not vice versa. So:

- `Int` for `Float64` is fine, but `Float64` for `Int` is not
- `Float64` for `ComplexF64` is fine, but `ComplexF64` for `Float64` is not
- etc.
That's one idea. The thinking is that the (co)tangent should be (in some sense) embedded in the (co)tangent space of the primal, which could mean it is in the (co)tangent space of a submanifold that also contains the primal or could mean that it is just in a subset of the (co)tangent space of the same manifold.
I do think though that this is a trickier rule to enforce than the `project_cotangent` approach proposed above. Whether the differential is a member of a superset or a subset of the same cotangent space, we can use dispatch to project it to our (co)tangent space; but projecting only when it is in a superset is trickier, because we can't in general distinguish the two for user-defined types.
Here's another case study relating to this issue: https://discourse.julialang.org/t/43778
I am not sure I understand the purpose of `project_cotangent()`. If `A isa Matrix` and `B isa Diagonal`, then `project_cotangent(*, B, A' * Ȳ)` involves computing the dense adjoint `A' * Ȳ` as an intermediate result, so we don't gain any performance by projecting this adjoint onto `Diagonal`. To gain performance, we need to merge `rrule()` and `project_cotangent()` such that no unnecessary intermediate entries are computed. So is the purpose of `project_cotangent()` to serve as a placeholder until an efficient implementation is provided?
Yup, that's exactly it. It's a stop-gap to help out AD systems that don't handle mutation properly and so rely heavily on pullbacks (I'm really just looking at Zygote here tbh).

edit: ofc we would set things up such that you can always optimise stuff still, and we would do so in all cases where we have the time / inclination to sort it out.
Any given pullback runs the risk of the cotangent wandering away from the cotangent space of the primal. In general it will be the same cotangent, just embedded in a larger cotangent space. For example, in `A * B`, if `A` is real and `B` is complex, the output will be complex, the cotangent of the output will (probably) be complex, and a naive pullback will give `A` a complex cotangent. Without accounting for this, we should expect cotangents to wander further and further away from the cotangent space of the primals. For reals, they will tend to complex; for `AbstractArray`s, they will tend to `Array`s.
This is problematic for 2 reasons: 1) the representation in the larger cotangent space will generally be less efficient, where operations have a worse time complexity than the primal. 2) as the program becomes larger and more complex with more user-defined rules, there's a higher chance that a pullback is encountered that makes assumptions about the cotangent (such as "the cotangent is real") that are violated by this cotangent vector. This can result in the computed gradient just being wrong.
`project_cotangent` is an imperfect solution. Frankly, I don't like it. You could have the pullback closure call another function that dispatches on the inputs, outputs, and adjoint, but this is just a band-aid on the problem. A user can always implement a new structured array type for which a custom pullback would be needed to have the right time complexity. I'm coming around to the idea that rules for abstract types are a bad idea, but not as bad as utter failure for non-mutating AD.
> It's a stop-gap to help out AD systems that don't handle mutation properly

Why does mutation matter in this context? The `Diagonal` vs `Matrix` example does not involve mutation.
> I'm coming around to the idea that rules for abstract types are a bad idea, but not as bad as utter failure for non-mutating AD

Isn't this exactly the same as with standard dispatch? For example, when you create a type `MyMatrix <: AbstractMatrix`, there's a generic fallback for matrix products involving `MyMatrix`, but performance will likely suck, in particular if `MyMatrix` is something sparse like `Diagonal`. If you want performance, you have to bite the bullet and actually implement your optimised matrix product.
As far as I can tell, it's the same here. `rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix) = A*B, dC -> (NO_FIELDS, dC*B', A'*dC)` is a reasonable fallback, but for many concrete types this won't be performant enough, so you have to buckle down and implement a specialised `rrule`.
One vote that `project_cotangent(x, ∂x)` sounds like a good idea. It seems unlikely that anyone will write all possible rules like `rrule(::typeof(*), A::UpperTriangular, B::AbstractMatrix)`, and I doubt that most of them would be more efficient than projecting in any case -- N^2 BLAS vs 1/2 N^2 on some ragged shape.

`Diagonal` is the obvious exception where a special rule really will be much more efficient. But since we can't dispatch on `::SetDiff{AbstractMatrix, Diagonal}` in order not to supply a rule at all, projecting the N^2 result of a naive rule down to N seems like a decent way to return the right answer inefficiently until someone gets around to writing a rule. Maybe `Symmetric` / `Hermitian` are also cases where specialised rules could be faster.
Possibly it should be `project_cotangent!!(typeof(x), ∂x)`, with license to mutate if possible & faster, and without keeping the whole `x` around if not otherwise needed.
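A sketch of what that variant could look like (the `!!` name and type-based dispatch are as proposed; the method bodies are assumptions):

```julia
using LinearAlgebra

# Dispatching on the primal's type means the pullback closure need not
# capture x itself; `!!` signals the function may reuse ∂x's storage.
project_cotangent!!(::Type{<:Diagonal}, ∂x::AbstractMatrix) = Diagonal(diag(∂x))
project_cotangent!!(::Type{Matrix{T}}, ∂x::Matrix{T}) where {T} = ∂x  # already right: no copy
project_cotangent!!(::Type{Matrix{T}}, ∂x::AbstractMatrix) where {T} = convert(Matrix{T}, ∂x)
```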
Re number types, I'm not sure https://github.com/JuliaDiff/ChainRules.jl/issues/232#issuecomment-660298499 is quite right. This is perfect: "Float64 for ComplexF64 is fine, but ComplexF64 for Float64 is not". But for integers, promoting to float seems perfectly fine. If you ask for a derivative at 3, you are asking about infinitesimal changes away from 3, and should be unsurprised to get a real number. But if 3 is a local minimum of `f`, you would be surprised to be told that the downhill direction is `0 + im` (because `f` happens to use complex numbers internally), which I think is what Zygote now does. It is reasonable (IMO) to insist that you ask the question with `gradient(f, 3+0im)` if you want a complex answer.
> Re number types, I'm not sure #232 (comment) is quite right. This is perfect: "Float64 for ComplexF64 is fine, but ComplexF64 for Float64 is not". But for integers, promoting to float seems perfectly fine. If you ask for a derivative at 3, you are asking about infinitesimal changes away from 3, and should be unsurprised to get a real number. But if 3 is a local minimum of f, you would be surprised to be told that the downhill direction is 0 + im (because f happens to use complex numbers internally) which I think is what Zygote now does. It is reasonable (IMO) to insist that you ask the question with gradient(f, 3+0im) if you want a complex answer.
I can certainly see your point; it seems perfectly reasonable to promote `Integer`s to other subtypes of `Real` until you start thinking about consistency. To my mind, "ComplexF64 for Float64 is not fine" implies that "Real for Integer is not fine", and "Matrix for Diagonal is not fine" and, more generally, that "interpreting any particular type as being embedded inside another is, for the sake of AD, not fine".
As regards people's expectations / intuition -- it's taken the AD community in Julia a surprisingly long time to realise how much of an issue this all is, so I would expect some of the results to be surprising to users of AD.
I suspect that the reason it's taken this long for it to become obvious is that it was much less of an issue when we didn't have to be able to define appropriate tangents for any given type, back when Tracker / ReverseDiff were being written. All you had to do there was define them for `Matrix` and `Float64` / `Float32` -- technically the trackable objects were subtypes of `AbstractMatrix` and `Real` respectively, but the only way in which anyone used them in practice was as wrappers around `Matrix{Float64}` / `Matrix{Float32}` and `Float64` / `Float32`, for which these issues simply don't arise.
I came across this issue a couple of times in the wild recently. I'm reporting on them here so that I don't forget about them.
Until recently, `logdet` was implemented for any `AbstractMatrix` in `ChainRules`, which meant that e.g. whenever you asked `Zygote` to differentiate through `logdet` of a `PDMat` you hit this rule rather than the other one. The interesting thing in this case is that it will probably cause some numerical issues, as the generic implementation involves explicitly constructing the inverse. This is something that the projection approach wouldn't be able to recover from.
We've also had an issue in KernelFunctions.jl whereby `map` is implemented too widely in `Zygote`. We had a couple of custom implementations of `map` for some special `AbstractVector`s for which the optimal thing to do would have been to naively differentiate through our implementation rather than utilise the rule in `Zygote`. In the end we wound up implementing a function called `_map` to work around the issue, which is really a shame.
> an issue in KernelFunctions.jl whereby map is implemented too widely in Zygote
I'm not sure I understand. You have a weird type `<:AbstractVector` for which Zygote defaults to something generic and slow. You would like to dispatch to a faster implementation. That seems very normal and very Julian. Are you saying that `@adjoint Base.map(f::Transform, x::ColVecs)` doesn't get called in your case? It seems to get called when I make a toy example; is this because something has changed? The more complicated code on master seems to have the same dispatch as the one you linked (9 May) but other changes.
> logdet of a PDMat you hit this rule rather than the other one.
Does "other one" mean one defined specifically for :: PDMat
? ("this" is the AbstractMatrix rule.) And again it sounds like dispatch ought to work, what am I missing?
Re integers, I guess the possible options are (1) they are categorical, all derivatives `Zero()`; or (2) their derivatives must be of (at most) the same type; or (3) they are points on the real line like any other, implicitly float unless the algebra happens to work out.

Zygote takes (3) right now. That's also the rule of every blackboard lecture on calculus ever. Option (1) would break, for instance, every example in Zygote's readme, which seems surprising. Under option (2), `sin'(0) === 1` is fine, while `sin'(1)` is an `InexactError`.

I don't think a choice here implies rules for complex (or matrix) types. Integers really are different. You can do calculus over R, or over C (or worse), and over R^N etc., but you can't do calculus over the integers. So either you don't at all, (1), or you declare them to be real numbers which happen to take up less blackboard space, (3).
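Option (2) in code (illustrative helper, not an existing API):

```julia
# Under option (2), the tangent must be representable in the primal's type.
as_tangent(x::Integer, dx::Real) = convert(typeof(x), dx)

as_tangent(0, cos(0.0))  # sin'(0) === 1, fine
as_tangent(1, cos(1.0))  # throws InexactError: 0.5403... is not an Int
```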
I hadn't seen #224, but it doesn't do anything special for integers, right? It won't produce a complex gradient for the power in `x^p` unless you provide `p::Complex`. Which fits the general rule of not using a tangent space with more dimensions than the input.
> I'm not sure I understand. You have a weird type <:AbstractVector for which Zygote defaults to something generic and slow. You would like to dispatch to a faster implementation. That seems very normal and very Julian. Are you saying that @adjoint Base.map(f::Transform, x::ColVecs) doesn't get called in your case? It seems to get called when I make a toy example, is this because something has changed? The more complicated code on master seems to have the same dispatch as the one you linked (9 May) but other changes.
Zygote's machinery works as intended / custom `@adjoint`s would work fine; the point is that we shouldn't have to define them in the first place. The usual rationale around falling back to generic implementations when you've not got a specialised one in Julia does not apply to AD.
The reason that it's acceptable to fall back to a slow definition in general Julia code is because you've not written specialised code, so a slow fallback is better than not having anything. Going with the running example of a `Diagonal` matrix, if you had only implemented `*(::AbstractMatrix, ::AbstractMatrix)` and multiplied two `Diagonal` matrices together, you wouldn't expect it to automatically exploit the diagonal structure in the matrix, because you've not written the code to do that. That's fine.

The same is not true for AD. Suppose you have indeed written `*(::Diagonal, ::Diagonal)` and invoke `*` on a pair of `Diagonal` matrices. Clearly, the usual thing will happen and the code will be specialised to the appropriate method when executing the forwards pass. Now consider doing AD -- assuming that you've not written a rule that applies to `*(::Diagonal, ::Diagonal)`, the AD tool will recurse into the definition of `*(::Diagonal, ::Diagonal)` and derive the appropriate rule in terms of the operations in which it is implemented.

If, on the other hand, one implements an adjoint for `*` on `AbstractMatrix`s, the AD tool is prevented from recursing into the code, as would be optimal.
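As a toy illustration (hypothetical function `fast_mul`; the commented-out rule stands in for an over-broad hand-written one):

```julia
using LinearAlgebra

# Specialised primal method: O(N) work. An AD tool recursing into this
# code would derive an O(N) pullback with a Diagonal-structured cotangent.
fast_mul(A::Diagonal, B::Diagonal) = Diagonal(A.diag .* B.diag)

# An over-broad hand-written rule such as
#
#   rrule(::typeof(fast_mul), A::AbstractMatrix, B::AbstractMatrix) =
#       fast_mul(A, B), Ȳ -> (NO_FIELDS, Ȳ * B', A' * Ȳ)
#
# would also match Diagonal arguments, preventing the AD tool from ever
# recursing into fast_mul and forcing dense products in the pullback.
```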
This is analogous to what is happening here with `map`. Zygote has too generic an `@adjoint` defined, and as a consequence wasn't able to recurse into the definition of `map` for a `ColVecs` object.
Does "other one" mean one defined specifically for :: PDMat? ("this" is the AbstractMatrix rule.) And again it sounds like dispatch ought to work, what am I missing?
Apologies, I definitely wasn't sufficiently clear before.
In `ChainRules` we used to have a rule with the following type signature:

```julia
rrule(::typeof(logdet), X::AbstractMatrix) = ...
```
and in `Zygote` we have this rule:

```julia
@adjoint logdet(C::Cholesky) = ...
```
`PDMats` defines the `logdet` of a `PDMat` here to be

```julia
LinearAlgebra.logdet(a::PDMat) = logdet(a.chol)
```
where `a.chol isa Cholesky` and `PDMat <: AbstractMatrix`. There are no custom rules specific to `PDMats` / `AbstractPDMat`s.
The desired behaviour is that `Zygote` recurses into the `logdet(::PDMat)` method, sees the `Cholesky` factorisation, and invokes the `@adjoint logdet(C::Cholesky) = ...`. Instead it was hitting `logdet(::AbstractMatrix)` and immediately calling `rrule(::typeof(logdet), X::AbstractMatrix)` rather than recursing a single step further and hitting the appropriate rule.
Regarding your comment on the integers: I sympathise with your position on the matter, and agree that it would be quite jarring from a user's perspective (and mine, when I'm being a user).
It would certainly be possible for a given AD tool to take a different strategy from ChainRules at the user-facing level. For example, the default behaviour might be to convert `Integer`s to `Float64`s before doing AD on a function, with an option to override this behaviour if a user really means `Integer`, not `Float64`.
I still believe that consistency has to be the goal of a rules system though.
> I hadn't seen #224, but it doesn't do anything special for integers right? It won't produce a complex gradient for the power in x^p unless you provide p::Complex. Which fits the general rule of not using a tangent space with more dimensions than the input.
Maybe @sethaxen can comment on this? I believe he wrote the code and understands the issue best.
> agree that it would be quite jarring from a user's perspective
To be clear, you are arguing for what I called option (1), or option (2)?
And can you phrase exactly what you mean by consistency here? One way to phrase it is "use the tangent space implicitly defined by dual numbers". This has 1 real dimension in the case of both floats & integers, 2 real dimensions for ComplexF64. And N not N^2 real dimensions in the case of a Diagonal matrix.
> To be clear, you are arguing for what I called option (1), or option (2)?
Consistency suggests (2), but the implication of not allowing a float-valued tangent for an `Integer` is that you have to use something `Integer`-like as a tangent, which doesn't seem to make sense. To me this implies that (1) is the most appropriate thing to do.
> And can you phrase exactly what you mean by consistency here? One way to phrase it is "use the tangent space implicitly defined by dual numbers". This has 1 real dimension in the case of both floats & integers, 2 real dimensions for ComplexF64. And N not N^2 real dimensions in the case of a Diagonal matrix.
I agree with the dimensionality aspect of what you're saying, and I agree with your characterisation of the dimensionality of the various tangent spaces defined by the dual numbers.
Although now that you're forcing me to be really specific (thanks for engaging with this, this is a helpful discussion), I think my point about consistency has been conflating two distinct issues:

1. whether a differential may live in a tangent space with more dimensions than the primal's (e.g. a dense `Matrix` differential for a `Diagonal`, or a complex differential for a real number), and
2. which representation to use within a tangent space of the same dimension, e.g. `Float64`, `Float32`, `Int64`, `Rational` etc.

We seem to have some kind of consensus around point number 1.
Our discussion around point 2 is trickier though. There are clearly situations in which someone writes `1` and really means `1.0` or `1f0`, but there are others in which they really mean `1` as in "a thing that I can count with", e.g. when writing a for-loop or indexing into an array.

I don't think that we have the ability to distinguish between the two in general other than from context. Seth's example with matrix powers feels like an example where it really matters, but I'm not completely sure.
Powers are tricky, but are the two problems obviously entangled? I wonder whether, in cases where previously Zygote returned a complex gradient for a real power, it shouldn't return something special to say "don't trust this gradient", which was often not needed anyway. Maybe a variant of `Zero()`, maybe just `NaN`? But perhaps I should think slower, and perhaps this long page isn't the right place.
Re dual numbers, if `Dual(3, (true, false)) |> sin` runs without error, does this imply we are in option (3)?
And, a corner case is whether a SparseMatrix has dimension N^2, or `nnz(A)`. I'd argue for the former, but notice this:
```julia
julia> A = spzeros(3,3); A[2,3] = 1; ForwardDiff.gradient(sum, A)
3×3 SparseMatrixCSC{Float64,Int64} with 1 stored entry:
  [2, 3]  =  1.0
```
Re types in general, should `x::Float32` be permitted gradient `dx::Float64`? Here pure maths is no help at all, obviously. But this promotion appears to be a common performance bug for Flux, as one `x .* 1/size(x,2)` will route everything to `generic_matmul`.
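A minimal illustration of the promotion trap (sizes invented):

```julia
x = rand(Float32, 128, 64)
W = rand(Float32, 128, 128)

dx = x .* (1 / size(x, 2))  # 1/64 is a Float64, so dx becomes Matrix{Float64}
eltype(dx)                  # Float64

# W * dx now mixes Float32 and Float64, missing BLAS and falling back to
# the much slower generic matrix multiply.
```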
Thanks for explaining why methods come with additional headache in AD as compared to "standard" Julia, @willtebbutt ! This issue has been mentioned in a number of discussions in this repo, and I never quite understood what the problem was, but now I do.
However, I am not sure I agree with the proposed solution. The fundamental problem is that we have two different axes along which we would like the compiler / AD tool to auto-generate code for us, namely the specialisation and differentiation axes. What you suggest in your post is that specialisation should take precedence over differentiation, but I am not sure that this is the right thing to do in all circumstances. Maybe this situation is the AD equivalent of an ambiguity error which needs to be manually resolved by a human?
The above issue is yet another example of this problem cropping up in practice. It's an interesting example that feels morally a bit different from the other examples here, in that it's a rule that's doing quite high-level stuff.
A couple of other misc. thoughts on this, in no particular order:

- One option is to only implement rules for methods with concrete signatures, e.g. `*(::Diagonal, ::Matrix)`.
- Whenever you define a new method of a function, you can think of there implicitly being an `rrule` specialised to that method (i.e. the method that `Zygote` would derive -- it doesn't matter when `Zygote` actually derives it, it exists from the moment that you define the method because everything is basically stateless). When you manually define an `rrule` for a more generic method of said function, said method takes precedence over the more specialised method. In this sense, it operates in the opposite manner to standard dispatch. So one fix for this would be to figure out how to ensure that if you define a really generic `rrule`, it doesn't override the `rrule`s automatically derived by more specific methods. If you could achieve this, you could ensure that `rrule`s are opt-in for new, more specific methods of a function, with the default thing that the AD tool would do taking precedence. Might be a better option as people begin to write more AD-aware code (i.e. avoiding mutation) and AD tooling continues to improve.

Not sure how helpful either of these are, but maybe food for thought 🤷
We have another contender! https://github.com/FluxML/Zygote.jl/issues/899
I think the “type constraints” label is rolling together several levels of too-broad definitions:
In Zygote#899 [Edit -- should be JuliaMath/FFTW.jl/issues/182], a function defined on an abstract type tries to access fields which not every subtype possesses. That's a problem long before talking about AD; at least it is [in #182 at least] a noisy one.
Next is things like https://github.com/FluxML/Zygote.jl/issues/916, where the rule for `dot(::Any, ::Any)` is giving the wrong answer for inputs it claims to accept. (It appears to assume that their shapes match.) [Edit -- maybe FluxML/Zygote.jl/issues/899 is in this class too, in fact.]
Maybe the next class is rules accepting `::Number` but failing to preserve real numbers. Or similarly failing to preserve `Symmetric <: AbstractMatrix`. (Or to preserve `Diagonal` in something like `ForwardDiff.gradient(x -> sum(prod(x, dims=1)), diagm(2:3))`.) As discussed above, https://github.com/JuliaDiff/ChainRules.jl/issues/232#issuecomment-657170934, a projection operator might be a better solution than demanding that every rule be super-careful -- it would be a lot of work to allow for `f(::Real, ::Complex)` by hand every time. https://github.com/FluxML/Zygote.jl/pull/965 is a recent attempt at this.
Finally, there are rules which are mathematically correct for the types they claim to support, but are inefficient. Often `Diagonal` is the obvious example, but the best `Diagonal` examples of this are the ones where the gradient is a `diagm(dx)`, correct but inefficient, rather than ones where off-diagonal elements are nonzero.
This last class is the main concern here:
> When you manually define an rrule for a more generic method of said function, said [rrule] method takes precedence over the more specialised method [implicitly defined by the forward code. ...] So one fix for this would be to figure out how to ensure that [...] a really generic rrule, it doesn't override the rrules automatically derived by more specific functions.
I don't think I appreciated this before, but the design of ChainRules is such that you can disable a generic `rrule` when you know that the more specialised method you just defined is AD-friendly. It doesn't quite work yet to define a specific `rrule` method that returns `nothing`, because Zygote's integration checks the method table, but perhaps this can be fixed: https://github.com/FluxML/Zygote.jl/pull/967 is a start.
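For illustration, the opt-out being described could look like this (hypothetical function `my_mul`; whether the `nothing` return is honoured is exactly what the linked PR is about):

```julia
using ChainRulesCore, LinearAlgebra

my_mul(A, B) = A * B

# Broad hand-written rule.
function ChainRulesCore.rrule(::typeof(my_mul), A::AbstractMatrix, B::AbstractMatrix)
    my_mul_pullback(Ȳ) = (NO_FIELDS, Ȳ * B', A' * Ȳ)
    return my_mul(A, B), my_mul_pullback
end

# Opt-out for the AD-friendly specialised method: returning `nothing`
# signals "pretend there is no rule here; recurse into the primal code".
ChainRulesCore.rrule(::typeof(my_mul), A::Diagonal, B::Diagonal) = nothing
```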
> It doesn't quite work yet to define a specific rrule method that returns nothing, because its Zygote interaction checks the method table, but perhaps this can be fixed: FluxML/Zygote.jl#967 is a start.
Care also needs to be taken that this mechanism works for Nabla and Yota. (I have been told it already works for Diffractor, and apparently even if it isn't inferrable to `Compiler.return_type`.)
Oh nice, I didn't know Yota was on board -- since https://github.com/dfdx/Yota.jl/pull/85 it seems. And it looks like https://github.com/invenia/Nabla.jl/pull/189 is the Nabla PR. Both look to be about as tricky as Zygote, i.e. they aren't just calling the `rrule`, which would automatically work.
Various discussions have been had in various places about the correct kinds of types to implement `rrule`s for, but we've not discussed this in a central location. This problem probably occurs for some `frule`s, but doesn't seem as prevalent as in the `rrule` case.

## Problem Statement

The general theme of the problem is whether or not to view certain types as being "embedded" inside others, for the purpose of computing derivatives. For example, is a `Diagonal` matrix one that just happens to be diagonal and is equally well thought of as a `Matrix`, or is it really its own thing? Similarly, is an `Integer` just a `Real`, or is it something else?

I will first attempt to demonstrate the implications of each choice with a couple of representative examples.

### `Diagonal` matrix multiplication

Consider a generic `rrule` implementation along the following lines (call the arguments `X` and `Y`, and the output cotangent `ΔΩ`):
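```julia
# Representative generic rule for matrix-matrix multiplication
# (sketch; the discussion below refers to a rule of exactly this shape).
function rrule(::typeof(*), X::AbstractMatrix, Y::AbstractMatrix)
    function times_pullback(ΔΩ)
        return (NO_FIELDS, ΔΩ * Y', X' * ΔΩ)
    end
    return X * Y, times_pullback
end
```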
If `X` and `Y` are `Matrix{Float64}`s, for example, then this is a perfectly reasonable implementation -- `ΔΩ` should also be a `Matrix{Float64}` if whoever is calling this rule is calling it correctly.

Things break down if `X` is a `Diagonal{Float64}`. The forwards-pass is completely fine, as is the computation of the cotangent for `Y`, `X' * ΔΩ`. However, the complexity of the cotangent representation / computation for `X` is now very concerning -- `ΔΩ * Y'` produces a `Matrix{Float64}`. Such a matrix is specified by `O(N^2)` numbers rather than the `O(N)` required for `X`, and requires `O(N^3)` time to compute, as opposed to the forwards-pass complexity of `O(N^2)`. This breaks the promise that the forwards- and reverse-pass time- and memory-complexity of reverse-mode AD should be the same, in essence rendering the structure in a `Diagonal` matrix useless if it is used in an AD system where this is the only rule for multiplication of matrices.

Moreover, what does it mean to consider a non-zero "gradient" w.r.t. the off-diagonal elements of a `Diagonal` matrix? If you take the view that it's just another `Matrix`, then there's no issue. The other point of view is that there's no meaningful way to define a non-zero gradient w.r.t. the off-diagonal elements of a `Diagonal` matrix without considering matrices outside of the space of `Diagonal` matrices -- intuitively, if you "perturb" an off-diagonal element, you no longer have a `Diagonal` matrix. Consequently, a `Matrix` isn't an appropriate type to represent the gradient of a `Diagonal`. If someone has a way to properly formalise this argument, please consider providing it.

It seems that the first view necessitates giving up on the complexity guarantees that reverse-mode AD provides, while the second view necessitates giving up on implementing `rrule`s for abstract types (roughly speaking). The former is (in my opinion) a complete show-stopper, while the latter is something we can in principle live with.

Of course you could add a specialised implementation for `Diagonal` matrices. However, I would suggest that you ponder your favourite structured matrix type and try to figure out whether it has similar issues. Most of the structured matrix types that I have encountered suffer from precisely this issue for many of the operations defined on them -- only those that are "dense" in the same way that a `Matrix` is do not. Consequently, it is not the case that we'll eventually reach a point where we've implemented enough specialised rules -- people will keep creating new subtypes of `AbstractMatrix` and we'll be stuck in a cycle of forever adding new rules. This seems sub-optimal given that a reasonable AD ought to be able to derive them for us. Moreover, whenever someone who isn't overly familiar with AD implements a new `AbstractMatrix`, they would need to implement a host of new `rrule`s, which also seems like a show-stopper.

### `Number` multiplication

Now consider implementing an `rrule` for `*` between two numbers. Clearly something like
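```julia
# One possible concrete rule (sketch): both arguments are Float64.
rrule(::typeof(*), x::Float64, y::Float64) = x * y, ΔΩ -> (NO_FIELDS, ΔΩ * y, ΔΩ * x)
```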
is a correctly implemented rule. `Float64` is concrete, so there's no chance that someone will subtype it and require a different implementation for their subtype. In this sense, we can guarantee that this rule is correct for any of the inputs that it admits (up to finite-precision arithmetic issues).

What would happen if you implemented this rule instead for `Real`s? Suppose someone provided an `Integer` argument for `y`; then its cotangent will probably be a `Float64`. While this doesn't present the same complexity issues as the `Diagonal` example above, treating the `Integer`s as being embedded in the `Real`s can cause some headaches, such as the ones that @sethaxen addressed in https://github.com/JuliaDiff/ChainRules.jl/pull/224 -- where it becomes very important to distinguish between `Integer` and `Real` exponents for the sake of performance and correctness. Since it is presumably not acceptable to sometimes treat the `Integer`s as special cases of the `Real`s and sometimes not, it follows that `*` should not be implemented between `Real`s, but between `AbstractFloat`s, if we're actually trying to be consistent.

Will this cause issues for users? Seth pointed out that the only situation in which this is likely to be problematic is the case in which an `Integer` argument is provided to an AD tool. This doesn't seem like a show-stopper.

## What Gives?
The issue seems to stem from implementing rules for types that you don't know about. For example, you can't know whether the `*` implementation above is suitable for all concrete matrices that subtype `AbstractMatrix`, even if they otherwise seem like perfectly reasonable matrix types.

How does this square with the typical uses of multiple dispatch within Julia? One common line of thought is roughly "multiple dispatch seems to work just fine in the rest of Julia, so why can't we just implement things as they're needed in this case?". The answer seems to be that, in regular Julia code, generic fallbacks for `AbstractMatrix`s will generally give the correct answer. This simply doesn't hold for rules if you're inclined to take the view that a `Matrix` isn't a suitable type to represent the gradient w.r.t. a `Diagonal`.

## What To Do?

The obvious thing to do is to advise that rules are only implemented when you know they'll be correct, which means you have to restrict yourself to implementing rules for concrete types, and collections thereof. Unfortunately, doing this will almost certainly break e.g. Zygote, because it relies heavily on very broadly-defined rules that in general exhibit all of the problems discussed above. Maybe some careful middle ground needs to be found, or maybe Zygote needs to press forward with its handling of mutation so that it can actually deal with such changes.
## Related Work

https://github.com/JuliaDiff/ChainRules.jl/issues/81

Writing this issue was motivated by https://github.com/JuliaDiff/ChainRules.jl/pull/226, where @nickrobinson251 pointed out that we should have an issue about this, and that we should consider adding something to the docs.
@sethaxen @oxinabox @MasonProtter any thoughts?
@DhairyaLGandhi @CarloLucibello this is very relevant for Zygote, so could do with your input.