JuliaDiff / ChainRules.jl

Forward- and reverse-mode automatic differentiation primitives for Julia Base + StdLibs

Type Constraints #232

Closed willtebbutt closed 3 years ago

willtebbutt commented 4 years ago

Various discussions have been had in various places about the correct kinds of types to implement rrules for, but we've not discussed this in a central location. This problem probably also occurs for some frules, but it doesn't seem as prevalent as in the rrule case.

Problem Statement

The general theme of the problem is whether or not to view certain types as being "embedded" inside others, for the purpose of computing derivatives. For example, is a Diagonal matrix one that just happens to be diagonal and is equally well thought of as a Matrix, or is it really its own thing? Similarly, is an Integer just a Real, or is it something else?

I will first attempt to demonstrate the implications of each choice with a couple of representative examples.

Diagonal matrix multiplication

Consider the following rrule implementation:

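# Assumes ChainRulesCore's `rrule` and `NO_FIELDS` are in scope (the API at the time of this thread).
function rrule(::typeof(*), X::AbstractMatrix{<:Real}, Y::AbstractMatrix{<:Real})
    function mul_AbstractMatrix_pullback(ΔΩ::AbstractMatrix{<:Real})
        # NO_FIELDS is the (absent) cotangent for `*` itself
        return NO_FIELDS, ΔΩ * Y', X' * ΔΩ
    end
    return X * Y, mul_AbstractMatrix_pullback
end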

If X and Y are Matrix{Float64}s for example, then this is a perfectly reasonable implementation -- ΔΩ should also be a Matrix{Float64} if whoever is calling this rule is calling it correctly.

Things break down if X is a Diagonal{Float64}. The forwards-pass is completely fine, as is the computation of the cotangent for Y, X' * ΔΩ. However, the complexity of the cotangent representation / computation for X is now very concerning -- ΔΩ * Y' produces a Matrix{Float64}. Such a matrix is specified by O(N^2) numbers rather than the O(N) required for X, and requires O(N^3) time to compute, as opposed to the O(N^2) complexity of the forwards-pass. This breaks the promise that the forwards- and reverse-pass time- and memory-complexity of reverse-mode AD should be the same, in essence rendering the structure in a Diagonal matrix useless if it is used in an AD system where this is the only rule for multiplication of matrices.
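A minimal sketch of the cost difference, in plain Julia with only the LinearAlgebra stdlib (it does not use the ChainRules API). The generic rule materialises the dense N×N matrix ΔΩ * Y', but for a Diagonal X only its N diagonal entries carry information. Each of those is a row-wise dot product, so all of them together cost O(N^2). The sizes and random values here are arbitrary test inputs.

```julia
using LinearAlgebra

N = 4
X = Diagonal(randn(N))   # structured primal: N stored numbers
Y = randn(N, N)
ΔΩ = randn(N, N)         # cotangent of Ω = X * Y

# Generic AbstractMatrix rule: a dense N×N cotangent, O(N^3) to compute.
∂X_dense = ΔΩ * Y'

# Structure-aware cotangent: (ΔΩ * Y')[i, i] = sum_j ΔΩ[i, j] * Y[i, j]
# for real Y, so only N numbers are needed, computed in O(N^2).
∂X_diag = Diagonal(vec(sum(ΔΩ .* Y; dims=2)))

length(∂X_dense)       # 16 == N^2 stored entries
length(∂X_diag.diag)   # 4 == N stored entries
```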

Moreover, what does it mean to consider a non-zero "gradient" w.r.t. the off-diagonal elements of a Diagonal matrix? If you take the view that it's just another Matrix, then there's no issue. The other point of view is that there's no meaningful way to define a non-zero gradient w.r.t. the off-diagonal elements of a Diagonal matrix without considering matrices outside of the space of Diagonal matrices -- intuitively, if you "perturb" an off-diagonal element, you no longer have a Diagonal matrix. Consequently, a Matrix isn't an appropriate type to represent the gradient of a Diagonal. If someone has a way to properly formalise this argument, please consider providing it.

It seems that the first view necessitates giving up on the complexity guarantees that reverse-mode AD provides, while the second view necessitates giving up on implementing rrules for abstract types (roughly speaking). The former is (in my opinion) a complete show-stopper, while the latter is something we can in principle live with.

Of course you could add a specialised implementation for Diagonal matrices. However, I would suggest that you ponder your favourite structured matrix type and try to figure out whether it has similar issues. Most of the structured matrix types that I have encountered suffer from precisely this issue with many operations defined on them -- only those that are "dense" in the same way that a Matrix is do not. Consequently, it is not the case that we'll eventually reach a point where we've implemented enough specialised rules -- people will keep creating new subtypes of AbstractMatrix and we'll be stuck in a cycle of forever adding new rules. This seems sub-optimal given that a reasonable AD ought to be able to derive them for us. Moreover, whenever someone who isn't overly familiar with AD implements a new AbstractMatrix, they would need to implement a host of new rrules, which also seems like a show-stopper.

Number multiplication

Now consider implementing an rrule for * between two numbers. Clearly

function rrule(::typeof(*), x::Float64, y::Float64)
    function mul_Float64_pullback(ΔΩ::Float64)
        # NO_FIELDS is the (absent) cotangent for `*` itself
        return NO_FIELDS, ΔΩ * y, ΔΩ * x
    end
    return x * y, mul_Float64_pullback
end

is a correctly implemented rule. Float64 is concrete, so there's no chance that someone will subtype it and require a different implementation for their subtype. In this sense, we can guarantee that this rule is correct for any of the inputs that it admits (up to finite-precision arithmetic issues).

What would happen if you implemented this rule instead for Reals? Suppose someone provided an Integer argument for y; then its cotangent will probably be a Float64. While this doesn't produce the same complexity issues as the Diagonal example above, treating the Integers as being embedded in the Reals can cause some headaches, such as the ones that @sethaxen addressed in https://github.com/JuliaDiff/ChainRules.jl/pull/224, where it becomes very important to distinguish between Integer and Real exponents for the sake of performance and correctness. Since it is presumably not acceptable to sometimes treat the Integers as special cases of the Reals and sometimes not, it follows that * should not be implemented between Reals but between AbstractFloats, if we're actually trying to be consistent.
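One illustration of why the Integer/Real distinction can matter for correctness. This is our own example and is not taken from the PR. If an exponent p is treated as a Real, the pullback of x^p must also produce ∂p = x^p * log(x). That throws a DomainError for a negative base, even though (-2.0)^3 is perfectly well defined. Both pullbacks below are hypothetical plain functions, not ChainRules rules.

```julia
# Hypothetical pullback of x^p treating the exponent p as a Real:
# ∂x = Δ * p * x^(p - 1) and ∂p = Δ * x^p * log(x).
function pow_pullback_real(x::Float64, p::Real, Δ::Float64)
    return Δ * p * x^(p - 1), Δ * x^p * log(x)
end

# Hypothetical pullback treating an Integer exponent as discrete:
# no cotangent is computed for p, so log(x) is never needed.
function pow_pullback_int(x::Float64, p::Integer, Δ::Float64)
    return Δ * p * x^(p - 1), nothing
end

pow_pullback_int(-2.0, 3, 1.0)       # (12.0, nothing)
# pow_pullback_real(-2.0, 3, 1.0)    # throws DomainError from log(-2.0)
```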

Will this cause issues for users? Seth pointed out that the only situation in which this is likely to be problematic is the case in which an Integer argument is provided to an AD tool. This doesn't seem like a show stopper.

What Gives?

The issue seems to stem from implementing rules for types that you don't know about. For example, you can't know whether the * implementation above is suitable for all concrete matrices that sub-type AbstractMatrix, even if they otherwise seem like perfectly reasonable matrix types.

How does this square with the typical uses of multiple dispatch within Julia? One common line-of-thought is roughly "multiple dispatch seems to work just fine in the rest of Julia, so why can't we just implement things as they're needed in this case?". The answer seems to be that

  1. while generic fallbacks in Julia can be slow, they're at least correct if your type correctly implements whichever interface it is meant to, e.g. the multiplication of two AbstractMatrix objects will generally give the correct answer. This simply doesn't hold if you're inclined to take the view that a Matrix isn't a suitable type to represent the gradient w.r.t. a Diagonal, and
  2. general Julia code doesn't provide you with asymptotic complexity guarantees in its fallbacks -- if you encounter a fallback and it has far worse complexity than you know is achievable for a particular operation on your type, you're unlikely to be too annoyed: how could you possibly expect Julia to magically exploit, e.g., some special structure in your matrix type unless you tell it how to? This is not the case with AD because a) reverse-mode AD provides complexity guarantees and b) the code to exploit structure was written to implement the forwards-pass, so a reasonable AD ought to be able to exploit it to construct an efficient pullback. It's very frustrating when this has been prevented through the implementation of a rule that "over-reaches" and stops an AD tool from doing what it's meant to be doing.

What To Do?

The obvious thing to do is to advise that rules are only implemented when you know they'll be correct, which means you have to restrict yourself to implementing rules for concrete types, and collections thereof. Unfortunately, doing this will almost certainly break e.g. Zygote, because it relies heavily on very broadly-defined rules that in general exhibit all of the problems discussed above. Maybe some careful middle ground needs to be found, or maybe Zygote needs to press forward with its handling of mutation so that it can cope with such changes.
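A toy sketch of the dispatch behaviour this restriction implies. `toy_rrule` is a hypothetical stand-in for `rrule`; ChainRulesCore's own fallback likewise returns `nothing` when no rule applies. A rule written only for `Matrix{Float64}` is never selected for a `Diagonal`, which is then left for the AD tool to differentiate through. The leading `nothing` in the pullback's tuple stands in for the cotangent of `*` itself.

```julia
using LinearAlgebra

# Hypothetical stand-in for rrule: `nothing` means "no rule here, let the AD
# tool differentiate through the primal code".
toy_rrule(f, args...) = nothing

# A rule restricted to concrete types: only dense Float64 matrices hit it.
function toy_rrule(::typeof(*), X::Matrix{Float64}, Y::Matrix{Float64})
    mul_pullback(ΔΩ) = (nothing, ΔΩ * Y', X' * ΔΩ)
    return X * Y, mul_pullback
end

Y = randn(3, 3)
toy_rrule(*, randn(3, 3), Y) !== nothing         # the dense rule applies
toy_rrule(*, Diagonal(randn(3)), Y) === nothing  # no rule: AD would recurse
```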

Related Work

https://github.com/JuliaDiff/ChainRules.jl/issues/81

Writing this issue was motivated by https://github.com/JuliaDiff/ChainRules.jl/pull/226, where @nickrobinson251 pointed out that we should have an issue about this, and that we should consider adding something to the docs.

@sethaxen @oxinabox @MasonProtter any thoughts?

@DhairyaLGandhi @CarloLucibello this is very relevant for Zygote, so could do with your input.

ettersi commented 4 years ago

I would argue that rules like rrule(f, X::Diagonal) are ambiguous, and we need to add an extra argument to the rrule interface to resolve the ambiguity. For example, for f(X::Diagonal) = A*X where A::Matrix, both

# Diagonal-as-matrix rule
rrule(::typeof(f), X::Diagonal) = A*X, Ȳ->A'*Ȳ

and

# Diagonal-as-diagonal rule
rrule(::typeof(f), X::Diagonal) = A*X, Ȳ->Diagonal([dot(A[:,i],Ȳ[:,i]) for i = 1:size(A,2)])

are reasonable definitions, but they correspond to two different rrules and hence issues arise because the current ChainRules interface does not allow us to distinguish these two. I believe a reasonable way to resolve this issue would be to introduce an additional argument as in

# Diagonal-as-matrix rule
rrule(::typeof(f), ::Type{Tuple{Matrix}}, X::Diagonal) = A*X, Ȳ->A'*Ȳ

# Diagonal-as-diagonal rule
rrule(::typeof(f), ::Type{Tuple{Diagonal}}, X::Diagonal) = A*X, Ȳ->Diagonal([dot(A[:,i],Ȳ[:,i]) for i = 1:size(A,2)])

(The exact form of the Type argument is up for discussion. For example, it is possible that replacing Matrix with AbstractMatrix is better.)
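A quick numerical check that the two pullbacks above agree on the diagonal, and that the diagonal-as-diagonal one is exactly the gradient with respect to the diagonal entries. This sketch uses plain closures rather than `rrule`. `A`, `d` and `Ȳ` are arbitrary test values, and `L` and `e` are helper functions introduced only for the check.

```julia
using LinearAlgebra

A = randn(3, 3)
f(X) = A * X    # the f from the comment, closing over a fixed A

# Diagonal-as-matrix pullback: dense cotangent A' * Ȳ
pullback_as_matrix(Ȳ) = A' * Ȳ
# Diagonal-as-diagonal pullback: only the diagonal survives
pullback_as_diagonal(Ȳ) = Diagonal([dot(A[:, i], Ȳ[:, i]) for i in 1:size(A, 2)])

d = randn(3)
Ȳ = randn(3, 3)

# L(d) = <Ȳ, f(Diagonal(d))> is linear in d, so a unit finite difference
# recovers ∂L/∂d[i] exactly (up to rounding).
L(d) = dot(Ȳ, f(Diagonal(d)))
e(i) = [k == i ? 1.0 : 0.0 for k in 1:3]
fd_grad = [L(d + e(i)) - L(d) for i in 1:3]
```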

To demonstrate how such rrules would be chained, let us consider a function fg(x::AbstractMatrix) = f(g(x)::AbstractMatrix). In such a case, the user would have to replace rrule(fg, x::Diagonal) with either rrule(fg, Tuple{Matrix}, x::Diagonal) or rrule(fg, Tuple{Diagonal}, x::Diagonal) to specify which derivative of fg they want, but then it is up to the AD tool to figure out which rrules to call for f and g. Writing MatOrDiag for whichever of these two types the user chose, the choice for g is fairly obvious: rrule(fg, MatOrDiag, x) must call rrule(g, MatOrDiag, x) to ensure that the final adjoint matches the user's expectation. For f, I believe the answer can be worked out as follows: if g(x::MatOrDiag) isa Diagonal, then call rrule(f, Diagonal, g(x)), and if g(x::MatOrDiag) isa Matrix, then call rrule(f, Matrix, g(x)). The key point here is that it is typeof(g(::MatOrDiag)), not typeof(g(::typeof(x))), which specifies which rrule to call for f. Of course, there is also the possibility that g(x) is not type-stable, in which case it is not clear to me what should happen. However, I believe we can postpone thinking about this until it is more clear that this is truly the direction we want to pursue.

willtebbutt commented 4 years ago

Hmm interesting proposal. If we conclude that they're both reasonable definitions, then that's certainly something that we should consider.

Another way of looking at this problem (that I forgot to mention before) is to consider what would happen in an AD system where you hadn't defined an rrule for any given operation involving a Diagonal, but you had defined the rules necessary to backprop through it (e.g. addition and multiplication of scalars in the matrix-multiplication case). If you applied Zygote to D * X, then the cotangent you would obtain for D would be a Composite{typeof(D)}(; diag=some_vector), rather than a Matrix.

So the argument goes as follows:

  1. Assume that an AD that doesn't have a particular new rrule emits the correct result (slowly).
  2. The purpose of adding a new rrule is only a performance optimisation, and shouldn't change the output of AD.
  3. Therefore if the new rrule outputs a Matrix rather than a Composite, it has produced the wrong answer.
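The argument can be checked on a small case. The sketch below is plain Julia with hypothetical function names. It mimics what an AD tool would produce by differentiating the scalar loop for Ω = D * X with D = Diagonal(d): only D's stored field `diag` receives a cotangent, returned here as a NamedTuple standing in for Composite. A hand-written rule is then only an optimisation if it returns the same thing.

```julia
using LinearAlgebra

# Cotangent that scalar-level AD would accumulate for D = Diagonal(d),
# given the forward loop Ω[i, j] = d[i] * X[i, j].
function structural_pullback_loop(d::Vector{Float64}, X::Matrix{Float64}, Ȳ::Matrix{Float64})
    ∂d = zeros(length(d))
    for j in axes(X, 2), i in axes(X, 1)
        ∂d[i] += Ȳ[i, j] * X[i, j]   # reverse of Ω[i, j] = d[i] * X[i, j]
    end
    return (diag = ∂d,)
end

# A hand-written "fast" rule is only an optimisation if it agrees with the above.
fast_structural_rule(d, X, Ȳ) = (diag = vec(sum(Ȳ .* X; dims=2)),)

d, X, Ȳ = randn(3), randn(3, 4), randn(3, 4)   # Ȳ has the shape of Ω = D * X
```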

As regards whether 1 holds or not -- you certainly need to have already defined the rrules necessary to be able to differentiate through the code of the operation for which you're defining the rule. For the code that we're talking about I would argue that this likely holds -- e.g. matrix-matrix multiplication just requires getindex, setindex! (Zygote doesn't currently implement this, but I believe it's well understood what it would implement), Float64 + Float64, and Float64 * Float64.

You could think of this argument as being one about consistency. It leans on the idea that there are a small set of basic rules, on whose behaviour everyone can agree, and that all other behaviour follows from there.

ettersi commented 4 years ago

I very much agree that it should be possible to split ChainRules into a few rules which define the behaviour and many other rules which provide performance improvements but which could be omitted without breaking any code. I doubt that you can do AD in a sensible way without this property, because it would mean the interface that a downstream programmer works with would depend on what optimisations are currently implemented.

I believe the core, API-defining rules are

If we decide not to go with the rrule(f, Type{TX...}, x...) interface mentioned above, then I think we do not have a choice but to follow the diagonal-as-diagonal version mentioned above, i.e.

# Diagonal-as-diagonal rule
rrule(::typeof(f), X::Diagonal) = A*X, (Ȳ::Matrix)->Diagonal([dot(A[:,i],Ȳ[:,i]) for i = 1:size(A,2)])
rrule(::Type{Diagonal}, d) = Diagonal(d), (dD::Diagonal)->diag(dD)
rrule(::typeof(getindex), D::Diagonal, i,j) = D[i,j], (dx::Number)->Diagonal([i == j == k ? dx : zero(dx) for k = 1:size(D,1)])

If we tried to follow the diagonal-as-matrix rule, then we would get into trouble defining the rrule for the constructor:

# Diagonal-as-matrix rule
rrule(::typeof(f), ::Type{Tuple{Matrix}}, X::Diagonal) = A*X, Ȳ->A'*Ȳ
rrule(::Type{Diagonal}, d) = Diagonal(d), (dD::Matrix) -> # The output has to be a Vector, so what could we possibly do here?
rrule(::typeof(getindex), D::Diagonal, i,j) = D[i,j], (dx::Number)->[ii == i && jj == j ? dx : zero(dx) for ii = 1:size(D,1), jj = 1:size(D,2)]
oxinabox commented 4 years ago

I very much agree that it should be possible to split ChainRules into a few rules which define the behaviour and many other rules which provide performance improvements but which could be omitted without breaking any code.

This can't be decided without knowing which AD you are talking about. E.g. Tracker, Zygote and Nabla all have different limitations on what they can AD through without custom rules, such as Zygote's inability to handle mutation. Thus for each AD the "instruction set" (as I have been calling the list of things that they must have rules for) differs.

Further, for many of them, providing rules only for Real arguments will work, but be so slow that it might as well not.

ettersi commented 4 years ago

I guess a more accurate formulation of the point I was trying to make is that we must be able to guarantee that the many rules we might define for a particular type are all compatible with one another, and one way (perhaps the only way?) to do that is to define the core rules and then require that any other rule must have the same signature as if some hypothetical AD tool had created the rule for us. I understand that different concrete AD tools have different capabilities, but I assume when their capabilities overlap they agree in what should be happening, and so I further assume that an ideal AD tool which can create all rules from the minimal core set is well defined.

Also, I am of course not advocating that ChainRules should not provide optimisation rules, just that these rules must be consistent with what we would get if they weren't there.

oxinabox commented 4 years ago

fair enough. That is a decent way to put it.

One problem is natural differential types, like those of most structured sparse matrices, or Periods for my pet DateTime examples.

As a rule, an AD system can only make structured differential types (i.e. Composite), but actually you really want to write rules using the Natural differential types.

willtebbutt commented 4 years ago

but actually you really want to write rules using the Natural differential types

Have you written this example down anywhere @oxinabox ? Would be good to gather these ideas here. In particular I would really like a crystal clear argument of why we should prefer to work with natural differentials.

ettersi commented 4 years ago

Have you written this example down anywhere @oxinabox ? Would be good to gather these ideas here. In particular I would really like a crystal clear argument of why we should prefer to work with natural differentials.

The DateTime example is discussed here

As a rule, an AD system can only make structured differential types (i.e. Composite), but actually you really want to write rules using the Natural differential types.

I believe this is a question of convention. We could either require that all types use Composite as their differential type, or we could allow each type to choose a suitable differential type but then require that they must be consistent about this. This does not break the "everything must be AD-able" rule, because it is up to us whether we interpret constructors and access functions as part of the core rule set to be defined by humans or something filled-in by AD.

Also, note that "differential type" in the above is meant in a duck-typing sense. It is okay for a type to choose Vector as its differential type and then return a SparseVector as a differential since the two types have compatible interfaces. But what would not be okay is for a type to sometimes use Vector and sometimes Composite as its differential, because then we would have to write all rules such that they can handle both Vector and Composite as input adjoints.

sethaxen commented 4 years ago

I keep getting sidetracked on writing this perspective up (it's long), so I'll just give the short version and apologize if it's unclear or inaccurate. I agree that ideally the right thing to do would be to write rules that do exactly what AD would have done without the rule, just more efficiently. Right now I really don't like the idea of only defining rules on functions of concrete types though. I implement a lot of custom array types, as do others, and I rely on rules defined on abstract arrays. Without them, e.g. using Zygote, I'm saddled with the (current) terrible performance of the getindex rules and iteration, and if I want better performance, I'd then need to reimplement a ton of LinearAlgebra rules that will be more or less identical to those for every other array type.

Treating abstract arrays as embedded in the arrays, it's safe to pull back dense array adjoints, and as long as 1) the initial adjoint was valid and 2) the rrule for the constructor is written as the composition of the constructor with the embedding into the arrays (e.g. see the rrule for Symmetric), and 3) the array is only ever treated as an array in the primal functions (so no accessing Symmetric(A).data, unless you add a custom rule to handle that like for the constructor), then everything should just work out.

But as @willtebbutt points out, you lose the time complexity of the primal function in the pullback. One way to partially get this back would be to "project" the output of each pullback to the predetermined differential type for its primal with a utility function. All rules with abstract types would do this.

So here's an example for *:

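using LinearAlgebra
using ChainRulesCore            # @thunk and NO_FIELDS, as in the API at the time of this thread
import ChainRulesCore: rrule    # needed to add methods to rrule

function rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
    function times_pullback(Ȳ)
        ∂A = @thunk(project_cotangent(*, A, Ȳ * B'))
        ∂B = @thunk(project_cotangent(*, B, A' * Ȳ))  # B's cotangent, projected onto B's type
        return (NO_FIELDS, ∂A, ∂B)
    end
    return A * B, times_pullback
end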

# defaults ignores function argument
project_cotangent(f, x, ∂x) = project_cotangent(x, ∂x)
# defaults to a no-op
project_cotangent(x, ∂x) = ∂x

# some possible projections
project_cotangent(x::Array, ∂x) = Array(∂x)
project_cotangent(x::Array, ∂x::Array) = convert(typeof(x), ∂x)
project_cotangent(x::Diagonal, ∂x) = Diagonal(∂x)
project_cotangent(x::LowerTriangular, ∂x) = LowerTriangular(∂x)
project_cotangent(x::UpperTriangular, ∂x) = UpperTriangular(∂x)
project_cotangent(x::Adjoint, ∂x::Adjoint) = ∂x
project_cotangent(x::Adjoint, ∂x) = Adjoint(project_cotangent(x.parent, adjoint(∂x)))
project_cotangent(x::Transpose, ∂x::Transpose) = ∂x
project_cotangent(x::Transpose, ∂x) = Transpose(project_cotangent(x.parent, transpose(∂x)))
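# x.uplo is stored as a Char ('U' or 'L'); the Symmetric/Hermitian constructors take a Symbol
project_cotangent(x::Symmetric, ∂x) = Symmetric(project_cotangent(x.data, symmetrize(∂x)), Symbol(x.uplo))
project_cotangent(x::Hermitian, ∂x) = Hermitian(project_cotangent(x.data, hermitrize(∂x)), Symbol(x.uplo))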

symmetrize(x) = (x .+ transpose(x)) ./ 2
hermitrize(x) = (x .+ x') ./ 2

I don't think this guarantees the right time complexity, but by preserving type information that can then be used for dispatch internally within this and subsequent pullbacks, it will likely be more efficient than the pullback using dense arrays.

The downside of this approach is that if someone implements a new array type but doesn't define a rule for the constructor, then there's a good chance they won't be able to use AD with the array.
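A standalone check of a few of the projections proposed above, applied to a 2×2 dense cotangent. It copies only the methods it needs, so it runs without ChainRulesCore. It also converts the stored `uplo` flag (a Char) to a Symbol, which is what the `Symmetric` constructor takes.

```julia
using LinearAlgebra

# Standalone copies of a few of the projections sketched above.
project_cotangent(x, ∂x) = ∂x                                    # default: no-op
project_cotangent(x::Diagonal, ∂x) = Diagonal(∂x)                # keep only the diagonal
project_cotangent(x::UpperTriangular, ∂x) = UpperTriangular(∂x)  # zero the strict lower part
project_cotangent(x::Symmetric, ∂x) =
    Symmetric(project_cotangent(x.data, symmetrize(∂x)), Symbol(x.uplo))

symmetrize(x) = (x .+ transpose(x)) ./ 2

∂x = [1.0 2.0; 3.0 4.0]   # a dense cotangent, e.g. from a generic rule

project_cotangent(Diagonal([1.0, 2.0]), ∂x)         # Diagonal([1.0, 4.0])
project_cotangent(UpperTriangular(ones(2, 2)), ∂x)  # [1.0 2.0; 0.0 4.0]
project_cotangent(Symmetric(ones(2, 2)), ∂x)        # [1.0 2.5; 2.5 4.0]
```

Note that each projection here runs after the dense cotangent has already been computed, so it restores the type but not the time complexity.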

willtebbutt commented 4 years ago

I agree with you that we're going to have to make some compromises in the short- to medium- term here to not completely break Zygote. There are probably a few options, and I like the one that you suggest @sethaxen, particularly for matrices that are (often) otherwise dense-with-constraints e.g. the -Triangulars and Symmetric etc.

One thing that might be worth working out is how custom rules for cases where the asymptotic complexity would otherwise take a hit (e.g. Diagonals) and the proposed approach work together. Possibly they just do, or maybe it requires some more thought / additional methods? E.g. if you return a Diagonal matrix to represent the cotangent of an UpperTriangular{T, Diagonal{T}}, does the "right" thing happen, so that you don't take a complexity hit?

oxinabox commented 4 years ago

but actually you really want to write rules using the Natural differential types

Have you written this example down anywhere @oxinabox ? Would be good to gather these ideas here. In particular I would really like a crystal clear argument of why we should prefer to work with natural differentials.

At least some of it should be written down in: https://www.juliadiff.org/ChainRulesCore.jl/dev/design/many_differentials.html

The best example I have for why one wants to work with natural differentials, and not always with structural ones, is matrix factorizations. They use Composite with properties corresponding to the properties of the factorizations, not to the fields of the matrix factorizations. (I think in the design docs I was calling this semi-structural?) That is because no one wants to deal with the fields of the matrix factorizations: they are not nice to work with. (Getting the last bit of performance out would probably require doing that, but that's another question.)

For Cholesky, you want to work with the properties U or L (https://github.com/JuliaDiff/ChainRules.jl/blob/d3cd83e5d202475fbc18de8fadea69fd9042f66b/src/rulesets/LinearAlgebra/factorization.jl#L87-L100). I don't think you want to deal with the true field, factors: https://github.com/JuliaLang/julia/blob/110765a87af68120f2f9f4aa0bbc4054db491359/stdlib/LinearAlgebra/src/cholesky.jl#L119-

Though maybe I am wrong in this case, since the relationship of .factors to .U and .L is simple (https://github.com/JuliaLang/julia/blob/110765a87af68120f2f9f4aa0bbc4054db491359/stdlib/LinearAlgebra/src/cholesky.jl#L411-L423), and working with .factors directly would save the effort of having to define custom addition.
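To make the field/property distinction concrete for Cholesky, here is a short example using only the LinearAlgebra stdlib. The stored fields are `factors` and `uplo`, while `U` and `L` are computed properties. Which triangle of `factors` holds the factor depends on `uplo`, so the check below handles both cases rather than assuming one. The matrix `A` is an arbitrary symmetric positive-definite example.

```julia
using LinearAlgebra

A = [4.0 2.0; 2.0 3.0]   # symmetric positive definite
C = cholesky(A)

C.factors   # raw field: dense storage, only one triangle holds the factor
C.uplo      # raw field: 'U' or 'L', which triangle of `factors` is used
C.U         # property: UpperTriangular factor with C.U' * C.U ≈ A
C.L         # property: LowerTriangular factor, equal to C.U'

# Relationship between the raw field and the properties:
uses_field = C.uplo == 'U' ? UpperTriangular(C.factors) == C.U :
                             LowerTriangular(C.factors) == C.L
```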

A better example is the BunchKaufman factorization. The relationship between its fields and its properties is complicated: https://github.com/JuliaLang/julia/blob/master/stdlib/LinearAlgebra/src/bunchkaufman.jl#L285-L321. Though here I don't want to deal with that one at all, neither to define addition nor to define the frules / rrules. But if I must, then I would rather do them separately, so that I can focus on calculus in the rule definition.

willtebbutt commented 4 years ago

Hmm well I think I'm probably okay with using differentials that aren't the Composite ones provided that they're guaranteed not to kill performance, which I think holds in this case.

I think this is probably the same reason that I think it's probably fine to represent the (co)tangent of a Matrix{Float64} with a Fill{Float64}, but not the other way around.

sethaxen commented 4 years ago

Related is https://github.com/JuliaDiff/ChainRulesCore.jl/issues/176. e.g. if Diagonal is an okay differential for LowerTriangular, but Matrix is not, then it follows that Int is an okay differential for Float64, but ComplexF64 is not.

willtebbutt commented 4 years ago

Ohhh so potentially it's generally the case that (if we adopt a strict convention) it's okay for a "bigger" type to have a differential represented by a "smaller" type, but not vice versa.

So

etc.

sethaxen commented 4 years ago

That's one idea. The thinking is that the (co)tangent should be (in some sense) embedded in the (co)tangent space of the primal, which could mean it is in the (co)tangent space of a submanifold that also contains the primal or could mean that it is just in a subset of the (co)tangent space of the same manifold.

I do think, though, that this is a trickier rule to enforce than the project_cotangent approach proposed above. Whether the differential is a member of a superset or a subset of the same cotangent space, we can use dispatch to project it to our (co)tangent space; but projecting only if it is in a superset is trickier, because in general we can't distinguish the two cases for user-defined types.

ettersi commented 4 years ago

Here's another case study relating to this issue: https://discourse.julialang.org/t/43778

ettersi commented 4 years ago

I am not sure I understand the purpose of project_cotangent(). If A isa Matrix and B isa Diagonal, then project_cotangent(*, A, A' * Ȳ) involves computing the dense adjoint A' * Ȳ as an intermediate result, so we don't gain any performance by projecting this adjoint onto Diagonal. To gain performance, we need to merge rrule() and project_cotangent() such that no unnecessary intermediate entries are computed. So is the purpose of project_cotangent() to serve as a placeholder until an efficient implementation is provided?

willtebbutt commented 4 years ago

Yup, that's exactly it. It's a stop-gap to help out AD systems that don't handle mutation properly and so rely heavily on pullbacks (I'm really just looking at Zygote here tbh)

edit: of course we would set things up such that you can still always optimise things, and we would do so in all cases that we have the time / inclination to sort out.

sethaxen commented 4 years ago

Any given pullback runs the risk of the cotangent wandering away from the cotangent space of the primal. In general it will be the same cotangent, just embedded in a larger cotangent space. For example, in A * B, if A is real and B is complex, the output will be complex, the cotangent of the output will (probably) be complex, and a naive pullback will give A a complex cotangent. Without accounting for this, we should expect cotangents to wander further and further away from the cotangent space of the primals. For reals, they will tend to complex; for AbstractArrays, they will tend to Arrays.

This is problematic for 2 reasons: 1) the representation in the larger cotangent space will generally be less efficient, where operations have a worse time complexity than the primal. 2) as the program becomes larger and more complex with more user-defined rules, there's a higher chance that a pullback is encountered that makes assumptions about the cotangent (such as "the cotangent is real") that are violated by this cotangent vector. This can result in the computed gradient just being wrong.
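A minimal sketch of the real-times-complex case, using plain Julia arrays with arbitrary values. The naive pullback hands the real matrix A a complex cotangent. Taking the real part is one way to project it back onto the real matrices. This assumes the convention that a complex cotangent is ∂/∂re + im * ∂/∂im; under that convention the gradient of a real loss with respect to a real A is real(Ȳ * B').

```julia
# Real primal A, complex primal B, so Ω = A * B is complex.
A = [1.0 2.0; 3.0 4.0]
B = [1.0+1.0im 0.0; 0.0 2.0-1.0im]
Ω = A * B

Ȳ = ones(ComplexF64, 2, 2)   # cotangent of the complex output Ω

# Naive pullback for A, as in the generic `*` rule: complex-valued,
# although A lives in the space of real matrices.
∂A_naive = Ȳ * B'

# Project back onto the real matrices (assumed convention: cotangent = ∂/∂re + im * ∂/∂im).
∂A_projected = real.(∂A_naive)   # [1.0 2.0; 1.0 2.0]
```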

project_cotangent is an imperfect solution. Frankly I don't like it. You could have the pullback closure call another function that dispatches on the inputs, outputs, and adjoint, but this is just a band-aid on the problem. A user can always implement a new structured array type for which a custom pullback would be needed to have the right time complexity. I'm coming around to the idea that rules for abstract types are a bad idea, but not as bad as utter failure for non-mutating AD.

ettersi commented 4 years ago

It's a stop-gap to help out AD systems that don't handle mutation properly

Why does mutation matter in this context? The Diagonal vs Matrix example does not involve mutation.

I'm coming around to the idea that rules for abstract types are a bad idea, but not as bad as utter failure for non-mutating AD

Isn't this exactly the same as with standard dispatch? For example, when you create a type MyMatrix <: AbstractMatrix, then there's a generic fallback for matrix products involving MyMatrix, but performance will likely suck, in particular if MyMatrix is something sparse like Diagonal. If you want performance, you have to bite the bullet and actually implement your optimised matrix product.
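The generic-fallback behaviour can be seen with a hypothetical `MyDiag` type that implements only `size` and `getindex`. Multiplication works and gives the right answer through LinearAlgebra's generic method, but it does not exploit the diagonal structure.

```julia
using LinearAlgebra

# Hypothetical minimal AbstractMatrix: a diagonal matrix implementing only
# the required `size` and `getindex` methods of the array interface.
struct MyDiag <: AbstractMatrix{Float64}
    d::Vector{Float64}
end
Base.size(M::MyDiag) = (length(M.d), length(M.d))
Base.getindex(M::MyDiag, i::Int, j::Int) = i == j ? M.d[i] : 0.0

M = MyDiag([1.0, 2.0])
B = [1.0 2.0; 3.0 4.0]

# Falls back to LinearAlgebra's generic matrix multiply: the result is
# correct, but the O(N^3) loop ignores the diagonal structure.
M * B   # [1.0 2.0; 6.0 8.0]
```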

As far as I can tell, it's the same here. rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix) = A*B, dC -> (NO_FIELDS, dC*B', A'*dC) is a reasonable fallback, but for many concrete types, this won't be performant enough so you have to buckle down and implement a specialised rrule.

mcabbott commented 4 years ago

One vote that project_cotangent(x, ∂x) sounds like a good idea. It seems unlikely that anyone will write all possible rules like rrule(::typeof(*), A::UpperTriangular, B::AbstractMatrix), and I doubt that most of them would be more efficient than projecting in any case -- N^2 BLAS vs 1/2 N^2 on some ragged shape.

Diagonal is the obvious exception where a special rule really will be much more efficient. But since we can't dispatch on ::SetDiff{AbstractMatrix, Diagonal} in order not to supply a rule at all, projecting the N^2 result of a naive rule down to N seems like a decent way to return the right answer inefficiently until someone gets around to writing a rule. Maybe Symmetric / Hermitian are also cases where specialised rules could be faster.

Possibly it should be project_cotangent!!(typeof(x), ∂x), with license to mutate if possible & faster, and without keeping the whole x around if not otherwise needed.

Re number types, I'm not sure https://github.com/JuliaDiff/ChainRules.jl/issues/232#issuecomment-660298499 is quite right. This is perfect: "Float64 for ComplexF64 is fine, but ComplexF64 for Float64 is not". But for integers, promoting to float seems perfectly fine. If you ask for a derivative at 3, you are asking about infinitesimal changes away from 3, and should be unsurprised to get a real number. But if 3 is a local minimum of f, you would be surprised to be told that the downhill direction is 0 + im (because f happens to use complex numbers internally) which I think is what Zygote now does. It is reasonable (IMO) to insist that you ask the question with gradient(f, 3+0im) if you want a complex answer.

willtebbutt commented 4 years ago

Re number types, I'm not sure #232 (comment) is quite right. This is perfect: "Float64 for ComplexF64 is fine, but ComplexF64 for Float64 is not". But for integers, promoting to float seems perfectly fine. If you ask for a derivative at 3, you are asking about infinitesimal changes away from 3, and should be unsurprised to get a real number. But if 3 is a local minimum of f, you would be surprised to be told that the downhill direction is 0 + im (because f happens to use complex numbers internally) which I think is what Zygote now does. It is reasonable (IMO) to insist that you ask the question with gradient(f, 3+0im) if you want a complex answer.

I can certainly see your point; it seems perfectly reasonable to promote Integers to other subtypes of Real until you start thinking about consistency. To my mind "ComplexF64 for Float64 is not fine" implies that "Real for Integer is not fine", and "Matrix for Diagonal is not fine" and, more generally, that "interpreting any particular type as being embedded inside another is, for the sake of AD, not fine".

As regards people's expectations / intuition -- it's taken the AD community in Julia a surprisingly long time to realise how much of an issue this all is, so I would expect some of the results to be surprising to users of AD.

I suspect that the reason it's taken this long for it to become obvious is that it was much less of an issue when we didn't have to be able to define appropriate tangents for any given type, back when Tracker / ReverseDiff were being written. All you had to do there was define them for Matrix and Float64 / Float32 -- technically the trackable objects were subtypes of AbstractMatrix and Real respectively, but the only way in which anyone used them in practice was as wrappers around Matrix{Float64} / Matrix{Float32} and Float64 / Float32, for which these issues simply don't arise.

Another Couple of Examples

I came across this issue a couple of times in the wild recently. I'm reporting on them here so that I don't forget about them.

Until recently, logdet was implemented for any AbstractMatrix in ChainRules which meant that e.g. whenever you asked Zygote to differentiate through logdet of a PDMat you hit this rule rather than the other one. The interesting thing in this case is that it will probably cause some numerical issues as the generic implementation involves explicitly constructing the inverse. This is something that the projection approach wouldn't be able to recover from.
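To make the numerical point concrete, here is a sketch contrasting the two routes. Both functions are hypothetical illustrations, not the actual ChainRules rule or Zygote's `@adjoint`: the generic rule uses d logdet(X)/dX = inv(X)', which forms an explicit inverse, whereas writing logdet(X) = 2 * sum(log, diag(U)) for X = U'U gives a cotangent for the factor U that needs no inverse at all.

```julia
using LinearAlgebra

# Generic rule (sketch): the cotangent for X is Δ * inv(X)', an explicit inverse.
logdet_pullback_generic(X::AbstractMatrix, Δ::Real) = Δ * inv(X)'

# Cholesky route (sketch): logdet(X) = 2 * sum(log, diag(U)) with X = U'U, so the
# cotangent with respect to the factor U is diagonal and involves no inverse.
function logdet_chol(U::UpperTriangular)
    y = 2 * sum(log, diag(U))
    pullback(Δ) = Diagonal(2Δ ./ diag(U))   # cotangent for U
    return y, pullback
end

X = [4.0 2.0; 2.0 3.0]
U = cholesky(X).U
y, back = logdet_chol(U)   # y ≈ logdet(X) == log(8)
```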

We've also had an issue in KernelFunctions.jl whereby map is implemented too widely in Zygote. We had a couple of custom implementations of map for some special AbstractVectors for which the optimal thing to do would have been to naively differentiate through our implementation rather than utilise the rule in Zygote. In the end we wound up implementing a function called _map to work around the issue, which is really a shame.

mcabbott commented 4 years ago

an issue in KernelFunctions.jl whereby map is implemented too widely in Zygote

I'm not sure I understand. You have a weird type <:AbstractVector for which Zygote defaults to something generic and slow. You would like to dispatch to a faster implementation. That seems very normal and very Julian. Are you saying that @adjoint Base.map(f::Transform, x::ColVecs) doesn't get called in your case? It seems to get called when I make a toy example, is this because something has changed? The more complicated code on master seems to have the same dispatch as the one you linked (9 May) but other changes.

logdet of a PDMat you hit this rule rather than the other one.

Does "other one" mean one defined specifically for ::PDMat? ("this" is the AbstractMatrix rule.) And again it sounds like dispatch ought to work; what am I missing?

mcabbott commented 4 years ago

Re integers, I guess the possible options are (1) they are categorical, all derivatives Zero(), or (2) their derivatives must be of (at most) the same type, or (3) they are points on the real line like any other, implicitly float unless the algebra happens to work out.

Zygote takes (3) right now. That's also the rule of every blackboard lecture on calculus ever. Option (1) would break for instance every example in Zygote's readme, which seems surprising. Under option (2), sin'(0)===1 is fine, while sin'(1) is an InexactError.
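Options (2) and (3) can be sketched with two hypothetical helpers that take a known derivative function (here `cos`, the derivative of `sin`) and apply each policy to an integer input.

```julia
# Option (2), sketched: the derivative must fit in the input's own type.
same_type_derivative(dfdx, x) = convert(typeof(x), dfdx(x))

# Option (3), sketched: integers are points on the real line, so evaluate as floats.
real_line_derivative(dfdx, x) = dfdx(float(x))

same_type_derivative(cos, 0)     # cos(0) == 1.0 converts exactly, giving 1
real_line_derivative(cos, 1)     # cos(1.0), a Float64
# same_type_derivative(cos, 1)   # throws InexactError, since cos(1) is not an integer
```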

I don’t think a choice here implies rules for complex (or matrix) types. Integers really are different. You can do calculus over R, or over C (or worse), and over R^N etc, but you can’t do calculus over the integers. So either you don’t at all, (1), or you declare them to be real numbers which happen to take up less blackboard space, (3).

I hadn’t seen #224, but it doesn’t do anything special for integers right? It won’t produce a complex gradient for the power in x^p unless you provide p::Complex. Which fits the general rule of not using a tangent space with more dimensions than the input.

willtebbutt commented 4 years ago

I'm not sure I understand. You have a weird type <:AbstractVector for which Zygote defaults to something generic and slow. You would like to dispatch to a faster implementation. That seems very normal and very Julian. Are you saying that @adjoint Base.map(f::Transform, x::ColVecs) doesn't get called in your case? It seems to get called when I make a toy example, is this because something has changed? The more complicated code on master seems to have the same dispatch as the one you linked (9 May) but other changes.

Zygote's machinery works as intended / custom @adjoints would work fine, the point is that we shouldn't have to define them in the first place.

The usual rationale around falling back to generic implementations when you've not got a specialised one in Julia does not apply to AD.

The reason that it's acceptable to fall back to a slow definition in general Julia code is because you've not written specialised code, so a slow fallback is better than not having anything. Going with the running example of a Diagonal matrix, if you had only implemented *(::AbstractMatrix, ::AbstractMatrix) and multiplied two Diagonal matrices together, you wouldn't expect it to automatically exploit the diagonal structure in the matrix, because you've not written the code to do that. That's fine.

The same is not true for AD. Suppose you have indeed written *(::Diagonal, ::Diagonal) and invoke * on a pair of Diagonal matrices. Clearly, the usual thing will happen and the code will be specialised to the appropriate method when executing the forwards pass. Now consider doing AD -- assuming that you've not written a rule that applies to *(::Diagonal, ::Diagonal), the AD tool will recurse into the definition of *(::Diagonal, ::Diagonal) and derive the appropriate rule in terms of the operations in which it is implemented.

If, on the other hand, one implements an adjoint for * on AbstractMatrix arguments, the AD tool is prevented from recursing into the specialised *(::Diagonal, ::Diagonal) code, even though recursing is what would be optimal.

This is analogous to what is happening here with map. Zygote has too-generic an @adjoint defined, and as a consequence wasn't able to recurse into the definition of map for a ColVecs object.
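The shadowing effect can be reproduced with a toy rule table. `toy_rrule` is a hypothetical stand-in, not the ChainRulesCore API; returning `nothing` plays the role of "no rule, recurse into the primal code". The forward pass dispatches to Base's specialised `Diagonal * Diagonal` method, but the rule lookup also dispatches and the generic `AbstractMatrix` rule matches first.

```julia
using LinearAlgebra

# Toy rule table (hypothetical): `nothing` means "no rule, recurse".
toy_rrule(f, args...) = nothing
toy_rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix) = :generic_matrix_rule

D = Diagonal([1.0, 2.0, 3.0])

P = D * D                  # forward pass: specialised method, result is still Diagonal
rule = toy_rrule(*, D, D)  # rule lookup: the generic rule wins, so an AD tool never recurses
```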

Does "other one" mean one defined specifically for :: PDMat? ("this" is the AbstractMatrix rule.) And again it sounds like dispatch ought to work, what am I missing?

Apologies, I definitely wasn't sufficiently clear before.

In ChainRules we used to have a rule with the following type signature

rrule(::typeof(logdet), X::AbstractMatrix) = ...

and in Zygote we have this rule

@adjoint logdet(C::Cholesky) = ...

PDMats defines the logdet of a PDMat here to be

LinearAlgebra.logdet(a::PDMat) = logdet(a.chol)

where a.chol isa Cholesky and PDMat <: AbstractMatrix. There are no custom rules specific to PDMats / AbstractPDMats.

The desired behaviour is that Zygote recurses into the logdet(::PDMat) method, sees the Cholesky factorisation, and invokes the @adjoint logdet(C::Cholesky) = .... Instead it was hitting logdet(::AbstractMatrix) and immediately calling rrule(::typeof(logdet), X::AbstractMatrix) rather than recursing a single step further and hitting the appropriate rule.

willtebbutt commented 4 years ago

Regarding your comment on the integers: I sympathise with your position on the matter, and agree that it would be quite jarring from a user's perspective (and mine, when I'm being a user).

It would certainly be possible for a given AD tool to take a different strategy from ChainRules at the user-facing level. For example, the default behaviour might be to convert Integers to Float64s before doing AD on a function, with an option to override this behaviour if a user really means Integer, not Float64.

I still believe that consistency has to be the goal of a rules system though.

I hadn’t seen #224, but it doesn’t do anything special for integers right? It won’t produce a complex gradient for the power in x^p unless you provide p::Complex. Which fits the general rule of not using a tangent space with more dimensions than the input.

Maybe @sethaxen can comment on this? I believe he wrote the code and understands the issue best.

mcabbott commented 4 years ago

agree that it would be quite jarring from a user's perspective

To be clear, you are arguing for what I called option (1), or option (2)?

And can you phrase exactly what you mean by consistency here? One way to phrase it is "use the tangent space implicitly defined by dual numbers". This has 1 real dimension in the case of both floats & integers, 2 real dimensions for ComplexF64. And N not N^2 real dimensions in the case of a Diagonal matrix.
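The dual-number dimension count can be written down directly with a hypothetical `tangent_dim` helper that counts real dimensions per type.

```julia
using LinearAlgebra

# Hypothetical helper: real dimensions of the tangent space implied by dual numbers.
tangent_dim(::Real) = 1                                  # Float64 and Int alike
tangent_dim(::Complex) = 2                               # real and imaginary parts
tangent_dim(x::AbstractArray) = sum(tangent_dim, x)      # every stored element counts
tangent_dim(x::Diagonal) = sum(tangent_dim, diag(x))     # only the N diagonal entries

tangent_dim(3), tangent_dim(1.0 + 2.0im)                 # (1, 2)
tangent_dim(Diagonal(ones(4))), tangent_dim(ones(4, 4))  # (4, 16)
```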

willtebbutt commented 4 years ago

To be clear, you are arguing for what I called option (1), or option (2)?

Consistency suggests (2), but the implication of not allowing a float-valued tangent for an Integer is that you have to use something Integer-like as a tangent, which doesn't seem to make sense. To me this implies that (1) is the most appropriate thing to do.

And can you phrase exactly what you mean by consistency here? One way to phrase it is "use the tangent space implicitly defined by dual numbers". This has 1 real dimension in the case of both floats & integers, 2 real dimensions for ComplexF64. And N not N^2 real dimensions in the case of a Diagonal matrix.

I agree with the dimensionality aspect of what you're saying, and I agree with your characterisation of the dimensionality of the various tangent spaces defined by the dual numbers.

Although now that you're forcing me to be really specific (thanks for engaging with this, this is a helpful discussion) I think my point about consistency has been conflating two distinct issues:

  1. the dimensionality of the tangent space being that implicitly defined by dual numbers, and
  2. the type that you're building your vectors / tangents from. For example, Float64, Float32, Int64, Rational etc.

We seem to have some kind of consensus around point number 1.

Our discussion around point 2 is trickier, though. There are clearly situations in which someone writes 1 and really means 1.0 or 1f0, but there are others in which they really mean 1 as in "a thing that I can count with", e.g. when writing a for-loop or indexing into an array.

I don't think that we have the ability to distinguish between the two in general other than from context. Seth's example with matrix powers feels like an example where it really matters, but I'm not completely sure.

mcabbott commented 4 years ago

Powers are tricky, but are the two problems obviously entangled? I wonder whether, in cases where previously Zygote returned a complex gradient for a real power, it shouldn't return something special to say "don't trust this gradient" which was often not needed anyway. Maybe a variant of Zero(), maybe just NaN? But perhaps I should think slower, and perhaps this long page isn't the right place.

Re dual numbers, if Dual(3, (true, false)) |> sin runs without error, does this imply we are in option (3)?
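A minimal dual number (a toy `ToyDual` type, not `ForwardDiff.Dual`) illustrates the question: seeding an integer primal and pushing it through `sin` yields a float-valued dual, i.e. behaviour consistent with option (3).

```julia
# Minimal dual number, for illustration only (not ForwardDiff.Dual).
struct ToyDual{T<:Real}
    val::T   # primal value
    der::T   # tangent (derivative) part
end

# Forward-mode rule for sin: d/dx sin(x) = cos(x).
Base.sin(x::ToyDual) = ToyDual(sin(x.val), cos(x.val) * x.der)

d = sin(ToyDual(3, 1))   # integer primal and seed; the result is a ToyDual{Float64}
```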

And, a corner case is whether a SparseMatrix has dimension N^2, or nnz(A). I'd argue for the former, but notice this:

julia> A = spzeros(3,3); A[2,3] = 1; ForwardDiff.gradient(sum, A)
3×3 SparseMatrixCSC{Float64,Int64} with 1 stored entry:
  [2, 3]  =  1.0

Re types in general, should x::Float32 be permitted gradient dx::Float64? Here pure maths is no help at all, obviously. But this promotion appears to be a common performance bug for Flux, as one x .* 1/size(x,2) will route everything to generic_matmul.
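A sketch of the promotion and one possible fix, assuming the scale factor is computed as a standalone Float64 value (written here with explicit parentheses). `project_precision` is a hypothetical helper that converts a cotangent back to the primal's element type.

```julia
# Hypothetical projection that keeps a cotangent in the primal's precision.
project_precision(x::AbstractArray{T}, dx::AbstractArray) where {T<:AbstractFloat} =
    convert(AbstractArray{T}, dx)

x = rand(Float32, 4, 3)
y = x .* (1 / size(x, 2))     # 1/3 is a Float64, so y is a Matrix{Float64}
z = x .* (1f0 / size(x, 2))   # a Float32 scale factor keeps Float32
dx = project_precision(x, y)  # converted back to Matrix{Float32}
```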

ettersi commented 4 years ago

Thanks for explaining why methods come with additional headache in AD as compared to "standard" Julia, @willtebbutt ! This issue has been mentioned in a number of discussions in this repo, and I never quite understood what the problem was, but now I do.

However, I am not sure I agree with the proposed solution. The fundamental problem is that we have two different axes along which we would like the compiler / AD tool to auto-generate code for us, namely the specialisation and differentiation axes. What you suggest in your post is that specialisation should take precedence over differentiation, but I am not sure that this is the right thing to do in all circumstances. Maybe this situation is the AD equivalent of an ambiguity error which needs to be manually resolved by a human?

willtebbutt commented 4 years ago

The above issue is yet another example of this issue cropping up in practice. It's an interesting example that feels morally a bit different from the other examples here, in that it's a rule that does quite high-level stuff.

A couple of other misc. thoughts on this, in no particular order:

Not sure how helpful either of these are, but maybe food for thought 🤷

willtebbutt commented 3 years ago

We have another contender! https://github.com/FluxML/Zygote.jl/issues/899

mcabbott commented 3 years ago

I think the “type constraints” label is rolling together several levels of too-broad definitions:

This last class is the main concern here:

When you manually define an rrule for a more generic method of said function, said [rrule] method takes precedence over the more specialised method [implicitly defined by the forward code. ...] So one fix for this would be to figure out how to ensure that [...] a really generic rrule, it doesn't override the rrules automatically derived by more specific functions.

I don't think I appreciated this before, but the design of ChainRules is such that you can disable a generic rrule, when you know that the more specialised method you just defined is AD-friendly. It doesn’t quite work yet to define a specific rrule method that returns nothing, because its Zygote interaction checks the method table, but perhaps this can be fixed: https://github.com/FluxML/Zygote.jl/pull/967 is a start.
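The opt-out idea can be sketched with a toy rule table (`opt_rrule` and `choose` are hypothetical, not the ChainRulesCore API). A more specific method returning `nothing` tells a tool that actually calls the rule to recurse instead; as noted above, tools that inspect the method table rather than calling the rule need extra work for this to take effect.

```julia
using LinearAlgebra

# Toy rule table (hypothetical): `nothing` means "no rule, recurse".
opt_rrule(f, args...) = nothing
opt_rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix) = :generic_matrix_rule
# Opt out for a type whose forward code is AD-friendly.
opt_rrule(::typeof(*), A::Diagonal, B::Diagonal) = nothing

# How a tool that calls the rule would decide what to do.
choose(f, args...) = opt_rrule(f, args...) === nothing ? :recurse : :use_rule

choose(*, Diagonal([1.0, 2.0]), Diagonal([3.0, 4.0]))   # :recurse
choose(*, ones(2, 2), ones(2, 2))                       # :use_rule
```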

oxinabox commented 3 years ago

It doesn’t quite work yet to define a specific rrule method that returns nothing, because its Zygote interaction checks the method table, but perhaps this can be fixed: FluxML/Zygote.jl#967 is a start.

Care also needs to be taken that this mechanism works for Nabla and Yota. (I have been told it already works for Diffractor, and apparently it works for Diffractor even if it isn't inferrable by Compiler.return_type.)

mcabbott commented 3 years ago

Oh nice, I didn't know Yota was on board, since https://github.com/dfdx/Yota.jl/pull/85 it seems. And it looks like https://github.com/invenia/Nabla.jl/pull/189 is the Nabla PR. Both look to be about as tricky as Zygote, i.e. they aren't just calling the rrule, which would automatically work.

mzgubic commented 3 years ago

closed by https://github.com/JuliaDiff/ChainRulesCore.jl/pull/385