Open willtebbutt opened 3 years ago
It's pretty easy to make Zygote give you errors from natural + structural, for instance:
julia> gradient(x -> (real(x) + x.im), 1+im)
ERROR: MethodError: no method matching +(::NamedTuple{(:re, :im), Tuple{Nothing, Int64}}, ::Int64)
julia> gradient(x -> (sum(x) + last(x)), 1:3)
ERROR: MethodError: no method matching +(::NamedTuple{(:start, :stop), Tuple{Nothing, Int64}}, ::FillArrays.Fill{Int64, 1, Tuple{Base.OneTo{Int64}}})
These don't seem to be very common in the wild? In these two cases, I think the rule for getproperty
should call a ProjectTo which standardises on the natural. But in general, it seems tricky.
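For anyone skimming, the failure mode can be reproduced without Zygote. A minimal sketch in plain Julia, where the two values are hand-written stand-ins for what the two paths in the first example send back (an assumption for illustration, not Zygote's internals):

```julia
# Hand-written stand-ins for the two tangents of x = 1 + im.
structural = (re = nothing, im = 1)  # shape produced by the pullback of x.im
natural = 1                          # shape produced by the pullback of real(x)

# Accumulating the two paths needs `+`, which Base does not define
# between a NamedTuple and an Int.
err = try
    structural + natural
catch e
    e
end

println(err isa MethodError)
```

Nothing is wrong with either tangent on its own; the error only appears when the two representations have to be accumulated.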
These don't seem to be very common in the wild?
Not super common, but I don't think I've worked on a project where I've not encountered them (for example, if you do enough things with a Diagonal, you will probably encounter them at some point), and I usually wound up type-pirating to get around them.
In these two cases, I think the rule for getproperty should call a ProjectTo which standardises on the natural.
Again, not completely sure that I agree re. Fills and naturals for the reasons laid out above (the natural for a Fill is deeply unintuitive to me, whereas the structural is simple). For example, we could / should @no_opt most things in FillArrays, including sum, and just get the structural.
edit: I meant @opt_out, not @no_opt.
It's pretty easy to make Zygote give you errors from natural + structural
For these cases one option is overloading +(::Tangent{FooArray}, ::AbstractArray), which could then do either a projection onto the natural, or a structuralization onto the Tangent.
Or potentially even a form of thunking that just delays until we go to accumulate against a primal (Keno was talking about that a while ago).
Possibly this should be a separate issue
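A minimal sketch of the "projection onto the natural" variant of that overload, for the Complex example. ToyTangent, zero_if_nothing and to_natural are hypothetical names invented here so the snippet runs in plain Julia; they are not ChainRulesCore API:

```julia
# Hypothetical stand-in for a structural tangent of a primal of type P.
struct ToyTangent{P,T<:NamedTuple}
    fields::T
end
ToyTangent{P}(fields::NamedTuple) where {P} = ToyTangent{P,typeof(fields)}(fields)

# Treat a missing (nothing) field tangent as a hard zero.
zero_if_nothing(x) = x === nothing ? false : x

# Projection onto the natural: rebuild a Complex from the field tangents.
to_natural(t::ToyTangent{<:Complex}) =
    complex(zero_if_nothing(t.fields.re), zero_if_nothing(t.fields.im))

# Mixed addition standardises on the natural.
Base.:+(t::ToyTangent{<:Complex}, z::Number) = to_natural(t) + z
Base.:+(z::Number, t::ToyTangent{<:Complex}) = z + to_natural(t)

t = ToyTangent{Complex{Int}}((re = nothing, im = 1))
println(t + 1)
```

This only works because Complex has an obvious natural; for types where the natural is costly or undefined, the other direction (structuralizing the natural) would presumably be the better fit.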
For these cases one option is overloading +(::Tangent{FooArray}, ::AbstractArray)
Certainly -- this kind of code has to be written if we're ever allowing naturals and structurals to interact. I think there's a whole separate issue to be had about the points in the code at which we want the structural and at which points we want the natural.
For example, I think it's reasonable to assert that we often want the natural inside rules, as evidenced by people's experience over the last few months (not always though). Conversely, constructors like structurals, so it seems possible that we want to ensure that naturals don't escape rules, or something like that.
Possibly this should be a separate issue
Yes.
In these two cases,
not completely sure that I agree re. Fills and naturals for the reasons laid out above (the natural for a Fill
Then you're looking at different examples, or talking about generalities. My specific two examples do not contain any tangent for an input x::Fill
, natural or otherwise.
But if you don't find the natural for x::Fill
intuitive, the natural gradient for a range will be worse.
However, the natural gradient for x::Complex is not so bad. So there are some structs for which we always want to normalise to natural. Notice also that in my example, it does not hit +(::Tangent{Foo}, ::Foo) -- the natural tangent from the other path lies in a subspace which isn't a subtype.
I think that's clearly what we want for SArrays too. Handling a Tangent{SVector}(; data=Tuple(...))
seems strictly inferior. In fact accessing its fields directly (in the forward code) seems extremely odd here, these are arrays. (Although obviously people do odd things in the wild & we should try to build guardrails, where possible.)
My specific two examples do not contain any tangent for an input x::Fill, natural or otherwise.
Wow, sorry, I completely misread your example. You're right that it's painful working with these types. This is why I prefer the structural tangent for them, although I also agree that it's not always what you need.
I think that's clearly what we want for SArrays too. Handling a Tangent{SVector}(; data=Tuple(...)) seems strictly inferior.
I agree in the specific case of SArray
s that we probably want them to be primitive.
In fact accessing its fields directly (in the forward code) seems extremely odd here, these are arrays. (Although obviously people do odd things in the wild & we should try to build guardrails, where possible.)
This I disagree with, in the general case. Array authors access fields all the time when writing methods optimised for their arrays, Fill arrays being a good example (getindex is essentially just getfield here). It's simply not odd to want to access fields of a structured array.
Again, that paragraph you are quoting from is specifically about SArrays, not generalities. I'm not sure I have ever seen anyone access the field by name; I had no idea what its name was. They are used as arrays in user code, and their interface with the outside world is through the usual AbstractArray interface. What the package authors do behind closed doors is, of course, nobody else's business.
I agree that what to do with other types is more tricky. Types that straddle the border between being containers and being arrays, I guess. Although I suppose I'd still like better examples of how these get used, to ground discussion. For instance field access is not really part of the API of Diagonal, and even less so for Symmetric (since it gets you meaningless data) but perhaps it is for some other weird type.
Again, that paragraph you are quoting from is specifically about SArrays, not generalities. I'm not sure I have ever seen anyone access the field by name; I had no idea what its name was. They are used as arrays in user code, and their interface with the outside world is through the usual AbstractArray interface. What the package authors do behind closed doors is, of course, nobody else's business.
Cool. Apologies for generalising! Yeah, we're definitely in agreement that users probably shouldn't be accessing SArrays via getfield, that just seems weird. That being said, if we implement getfield properly, it should just be fine 🤷 (if ill-advised).
Although I suppose I'd still like better examples of how these get used, to ground discussion.
Well the ColVecs / RowVecs / Fill / WoodburyPDMat are all examples intended for this purpose.
ColVecs / RowVecs / WoodburyPDMat are of the "don't ever give me a natural tangent" flavour -- they're entirely written with a small set of AD-friendly operations in mind, so we have no need to write rrules for them, except perhaps to @non_differentiable some things here and there, and @opt_out of any generic rules that we don't like (although I don't think we've encountered any yet fortunately -- might change as people start writing more generically-typed rules though).
Fill arrays are a bit different -- pretty much every function they provide specialised implementations for should produce a structural tangent just fine without writing any rrules (e.g. sum, broadcast, map etc), but there are probably others that don't have specialised implementations and hit generic fallbacks, and therefore need the fallback rules + projection mechanism. So Fill is a good example where we should expect a mixture of different tangent types to appear during the reverse-pass of AD.
Are these examples missing something that you're interested in seeing?
For instance field access is not really part of the API of Diagonal, and even less so for Symmetric (since it gets you meaningless data) but perhaps it is for some other weird type.
I agree that field access is typically not part of the user-facing API of an array (although I also agree that it might be in some weird instances). However, AD ought to just work on functions written by library authors if they write code with AD in mind (library authors shouldn't have to write rrules if they write their code in an AD-friendly manner).
That being said, if they also want their types to work with the generic fallbacks defined for AbstractArrays, they'll need to write ProjectTo, and functionality to convert natural tangents into structural tangents so that one can AD through the constructors (unless they make ProjectTo return a structural, in which case the natural doesn't exist. Not sure what kinds of problems that can cause though -- I imagine it precludes the use of some generic rrules that they might be interested in. As you say though, examples of this would help the discussion).
As regards people accessing internal fields of types when they probably shouldn't be, I agree that they probably shouldn't be. However, we do know how to AD through ill-advised field access, and we know how to represent the tangent. To my mind, from an AD's perspective, there's not really a distinction between approved field access by type authors, and ill-advised field access by misguided users.
Taking the Diagonal
example, if someone writes
foo(D::Diagonal) = sum(D.diag)
we can AD it with no problem provided that the rule system doesn't somehow get in the way. If the standard library changes the internals of Diagonal
and breaks a user's code, then of course the user is entirely to blame.
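To make the Diagonal case concrete, here is a minimal sketch with the tangent written out by hand (no AD involved, so the structural NamedTuple below is my own derivation rather than the output of any package). It checks that the structural and natural representations predict the same change in foo:

```julia
using LinearAlgebra

foo(D::Diagonal) = sum(D.diag)

D = Diagonal([1.0, 2.0, 3.0])

# Hand-derived structural tangent of foo at D: one entry per field of Diagonal.
structural = (diag = ones(length(D.diag)),)

# The corresponding natural tangent is itself a Diagonal.
natural = Diagonal(structural.diag)

# Both representations predict the same change along a Diagonal perturbation.
dD = Diagonal([0.1, -0.2, 0.3])
actual = foo(D + dD) - foo(D)
via_structural = dot(structural.diag, dD.diag)
via_natural = dot(natural, dD)

println((actual, via_structural, via_natural))
```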
examples of how these get used
Are these examples missing something
Yes, thanks for the types, but what I meant was whether there are closer to end-to-end examples which give errors / wrong results / are inefficient. If ColVecs / RowVecs are glorified NamedTuples, which are always treated as structs, always Tangent, then there is no problem right?
The tricky cases seem to be (1) where some generic rule passes a special type along, and hence its generic pullback probably expects an AbstractArray, and (2) where the same object is used twice, and one path treats it as an array, the other as a struct, hence +
fails. My explicit examples were of (2).
The simplest case I can picture for (1) is this, but I don't know why it isn't an error:
julia> gradient((x,y) -> (x*y).re, 1+im, 2)
((re = 2, im = nothing), (re = 1 - 1im, im = nothing))
Of course, this one should be avoided in multiple ways: You should write real()
not .re
to be safe. The getfield of Complex should not construct a Tangent/NamedTuple, since Complex is clearly simple enough to always be its own gradient; Maybe that needs a special overload as you suggest for SArray, or maybe the rule for getproperty
just calls ProjectTo which handles it.
But there might be others which are less easy to avoid -- what are the simplest such? For Fill
I can invent things if I use field access. But, the entire point of Fill is that it's a cheap thing to pass into generic array code which does not know that it's this type... if you knew this all the way through, you would use a number... so this seems like cheating:
julia> gradient(x -> sum(x .+ 10), Fill(3,3))
(3-element Fill{Int64}: entries equal to 1,)
julia> gradient(x -> (x .+ 10).value, Fill(3,3)) # generic unbroadcast does not handle Tangent/NamedTuple
ERROR: MethodError: no method matching size(::NamedTuple{(:value, :axes), Tuple{Int64, Nothing}})
Stacktrace:
[1] unbroadcast(x::Fill{Int64, 1, Tuple{Base.OneTo{Int64}}}, x̄::NamedTuple{(:value, :axes), Tuple{Int64, Nothing}})
@ Zygote ~/.julia/packages/Zygote/TaBlo/src/lib/broadcast.jl:51
If ColVecs / RowVecs are glorified NamedTuples, which are always treated as structs, always Tangent, then there is no problem right?
You're correct. Again, lack of clarity on my part -- with these kinds of things I was just trying to assemble some concrete examples of things that do indeed work just fine, so that I can point to them if anyone is ever confused about the need for structural tangents for AbstractArrays.
The tricky cases seem to be (1) where some generic rule passes a special type along, and hence its generic pullback probably expects an AbstractArray, and (2) where the same object is used twice, and one path treats it as an array, the other as a struct, hence + fails. My explicit examples were of (2).
Agreed. The options for resolving case 2 are clearer to me than 1. Either, you must (I think)
1. ensure that if structural + natural ever occurs, something sensible happens, and you must define a constructor to accept the natural, or
2. we insist that natural never escape pullbacks.
I quite like the second approach because, unless I'm mistaken, it would make explicit what I think we're currently doing implicitly in a bunch of rrules for constructors (conversion from natural to structural), for example Diagonal.
I'm struggling to come up with good examples for (1). Any thoughts?
But, the entire point of Fill is that it's a cheap thing to pass into generic array code which does not know that it's this type... if you knew this all the way through, you would use a number...
This is not always the case. My favourite example of this is in AbstractGPs. The third field of our FiniteGP type is always a covariance matrix (I'm not sure why we don't make it subtype AbstractMatrix now that I read the code again). Very often, it will be either
Diagonal(some_Vector_of_Reals)
Diagonal(Fill(s, N))
for some real number s and integer N, and some_Vector_of_Reals isa AbstractVector{<:Real}. A thing we rely on is specialising based on whether or not this matrix is Diagonal or a Diagonal{<:Real, <:Fill}, because there are often entirely different algorithms that we can use in those cases. TemporalGPs relies on this entirely (because efficient inference isn't possible if this matrix isn't diagonal), and OILMMs.jl (which is currently being reworked into a more general package) only works when this is a Diagonal{<:Real, <:Fill}.
We need it to be an AbstractMatrix{<:Real}
so that we can be sure it can be added to another matrix of the same size (wouldn't work for scalars or vectors). Generally speaking, if we hit this code path we couldn't care less what kind of matrix it is.
So I think the way in which this differs from what you're presuming is that some of the time we couldn't care less whether we have a Fill
or not, but other bits of the time we really do.
But there might be others which are less easy to avoid -- what are the simplest such?
I think the Fill example is a good one. It can be made to feel a bit more realistic by opting out of sum by defining my_sum (if I'm not mistaken Zygote defines a rule for sum). Something like
julia> using FillArrays
julia> using Zygote
julia> using FillArrays: AbstractFill
julia> # Work around a Zygote performance bug by using Zygote.literal_getfield manually.
my_sum(x::AbstractFill) = Zygote.literal_getfield(x, Val(:value)) * length(x)
my_sum (generic function with 1 method)
julia> Zygote.gradient(my_sum, Fill(5.0, 10))
((value = 10.0, axes = nothing),)
julia> Zygote.gradient(x -> my_sum(x .+ 10), Fill(5.0, 10))
ERROR: MethodError: no method matching size(::NamedTuple{(:value, :axes), Tuple{Float64, Nothing}})
we insist that natural never escape pullbacks
I'm not sure what this means exactly. By "natural tangent" we mean something returned by a pullback which isa AbstractArray
, no? (Or isa Number
) What happens inside the function we don't care.
What I'm hoping to elucidate with examples is whether any types actually have to straddle the border between array-like and struct-like. I don't see why we shouldn't treat Fill as array-like, 100% of the time. Like Complex should be Number-like, 100% of the time. Conversion is free. Whereas it sounds like ColVecs / RowVecs are like NamedTuples, structural 100% of the time --- maybe they are what you mean by "never escapes"? Can we partition all types into one category or the other?
I'm not sure what this means exactly. By "natural tangent" we mean something returned by a pullback which isa AbstractArray, no? (Or isa Number)
Yeah, more or less. My working definition is (excluding primals like Array and Float64 etc) that a natural tangent (/ differential, I've stopped calling them differentials) is anything that's not a Tangent, or one of the other tangent types in CRC (ZeroTangent, NoTangent, etc).
What happens inside the function we don't care.
Definitely -- modulo it returning acceptable types.
What I'm hoping to elucidate with examples is whether any types actually have to straddle the border between array-like and struct-like.
Symmetric / Diagonal have to straddle, because they can wrap arbitrary AbstractMatrixs / AbstractVectors respectively, and said AbstractMatrix / AbstractVector might be a container-like AbstractArray, meaning that it only has a structural tangent, which can't be put inside a Symmetric / Diagonal.
e.g. the tangent of a Symmetric{<:Real, <:WoodburyPDMat} would have to be a Tangent, because the tangent of a WoodburyPDMat must be a Tangent.
I don't see why we shouldn't treat Fill as array-like, 100% of the time.
This is a fair point -- I also can't see any reason, beyond my lingering doubts about the properties of using it as its own tangent (I'm concerned that sometimes it needs to behave like a struct and other times like an array, and maybe that leads to inconsistencies. I'm planning on trying to convince myself one way or the other over the next couple of days).
maybe they are what you mean by "never escapes"?
By "never escapes", I mean that a pullback always returns a Tangent
(or ZeroTangent
etc).
My understanding is that natural tangents are primarily a way to
This suggests to me that they're important in exactly two places:
Outside of these two situations, it's potentially a different matter. For example, inside the AD pipeline, we could require all tangents to be structural (or ZeroTangent etc), which would then be simple to deal with (+ always well defined, default constructor rrules happy). We could make this easy for users to do by requiring project_to to instead return the structural tangent (presumably a breaking change, not possible in 1.x, I assume?), and providing another function to do the structural -> natural mapping (Tangent to Symmetric, for example, when it makes sense to do, and erroring if the particular Tangent doesn't have a corresponding natural tangent). This could be done at the start of a pullback definition, if the user would rather have e.g. an AbstractArray than a Tangent (could also unthunk automatically etc).
My suspicion is that this kind of change would be minimally invasive, and would have the benefit of localising errors to the inside of rrules (instead of not being able to +, conversion to / from natural tangent might fail. The former case tends to happen in the middle of an AD pipeline, while the latter should happen inside an rrule that a human wrote. Consequently, the latter is probably easier to understand and debug). It would also make the rule-writer's expectations about the types they're going to get be really explicit. By asking for the natural tangent, you're saying "I want something array-like, on which I can perform the usual array operations" (in the context of tangents for AbstractArrays).
For example, the rule for *(::AbstractVecOrMat, ::AbstractVecOrMat)
would become something like
function rrule(
    ::typeof(*),
    A::AbstractVecOrMat{<:CommutativeMulNumber},
    B::AbstractVecOrMat{<:CommutativeMulNumber},
)
    Y = A * B
    project_Y = ProjectToNatural(Y)
    project_A = ProjectTo(A)
    project_B = ProjectTo(B)
    function times_pullback(ȳ)
        Ȳ = project_Y(ȳ)
        dA = @thunk(project_A(Ȳ * B'))
        dB = @thunk(project_B(A' * Ȳ))
        return NoTangent(), dA, dB
    end
    return Y, times_pullback
end
edit: re FillArrays
, if you have an operation which AD successfully derives, it'll presumably expect a Tangent
as the tangent for its output on the reverse-pass. Trying to find a good example now.
edit2: no, with getfield
implemented appropriately, this should be fine.
edit3: break up long sentence.
By "never escapes", I mean that a pullback always returns a Tangent
You mean every pullback always returns a Tangent, I think. (Or maybe not for Complex
.) I'm going to dub this the nuclear option. I think it means that arrays unknown to the AD system can never propagate through AD, as arrays. They all need special reconstructor functions before they can be used within generic rules. It can't "just work" with SArrays & OffsetArrays.
The "not straddling" idea is I think to tie the behaviour to type of x
. Perhaps, for some types, all pullbacks always return a Tangent, while for others, no pullback ever returns such. The type of x::Symmetric
contains the type of parent(x)
, maybe that counts. If every x
is on one side of the fence or the other, then my problem (2) never occurs, +
is never confused.
I still think we need better examples of problem (1). That is, sufficiently weird array types not to have easy natural tangents, which are nevertheless the result of the forward pass of some generic function, whose generic rrule we might hope to use. If the construction of a natural tangent is possible but expensive, then you can see some virtue to delaying it until certainly necessary. If the construction is impossible, then you cannot use the generic rrule under any scheme, so we're not in problem (1) anymore. (You will either need to write another, or opt-out.)
I'm going to dub this the nuclear option.
Haha fair enough. I would prefer it to be characterised as the "clean" option, in that it cleanly separates the world of the users / rule-implementers from the world of the AD.
I think it means that arrays unknown to the AD system can never propagate through AD, as arrays.
Is this a problem? AD systems like structural tangents -- that's their "native" representation of any struct
, and by extension their native representation of any AbstractArray
that isn't a primitive.
They all need special reconstructor functions before they can be used within generic rules.
I'm not convinced we don't already have to do this somewhere, it's just unclear where. For example, we already have to write code to map from the natural to the structural inside a hand-written rrule for the constructor.
How is it that an AbstractArray
tangent comes into existence in the first place if not via conversion from a structural tangent? (structural tangents being the only thing that can be generated automatically).
The "not straddling" idea is I think to tie the behaviour to type of x. Perhaps, for some types, all pullbacks always return a Tangent, while for others, no pullback ever returns such. The type of x::Symmetric contains the type of parent(x), maybe that counts. If every x is on one side of the fence or the other, then my problem (2) never occurs, + is never confused.
Sorry, I'm really struggling to follow this para. Could you elaborate?
I still think we need better examples of problem (1).
Is the Fill
example not sufficient? It has the key properties that we care about, does it not?
edit: ignore my last comment. Reading your para again.
How is it that an AbstractArray tangent comes into existence in the first place if not via conversion from a structural
The generic rule for sum(x) would be something like dx = similar(x) .= 1, which will produce a CuArray, or a StaticArray, or an OffsetArray, etc. Many are like this.
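A minimal sketch of that pattern, assuming nothing beyond Base. sum_pullback_sketch is a hypothetical name for illustration, not the rule that ChainRules actually ships:

```julia
# Sketch of a generic pullback for sum(x): every entry of x receives the cotangent ȳ.
function sum_pullback_sketch(x::AbstractArray, ȳ)
    dx = similar(x, typeof(ȳ))  # container chosen by x's own array type
    fill!(dx, ȳ)                # relies on the result of similar being mutable
    return dx
end

x = rand(2, 3)
dx = sum_pullback_sketch(x, 1.0)
println(typeof(dx) == typeof(x))
```

The point is that the natural tangent here is created directly by the rule via similar, with no structural tangent involved at any stage.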
we already have to write code to map from the natural to the structural inside a hand-written rrule for the constructor
Yes, usually, this is why my (2) is the easy problem. But the reconstructor is the map in the other direction.
Fill's "reconstructor" is trivial and free. The simplest Symmetric's is easy but not necessarily cheap. I am not sure whether this is possible at all for some of the weird wrapper combinations you mention. But I also don't know whether those can plausibly be produced by a generic rule as specified. If they cannot, then problem (1) does not occur.
Problem (2) is about two tangents for the same x; if dx-always-natural vs. dx-always-structural was determined by typeof(x) (and enforced somehow, by projectors) then it is almost a tautology that they will always agree. How easily this could be implemented would need more thought.
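One way to picture "determined by typeof(x)" is a trait on the primal type. This is only a sketch under my own assumptions: TangentStyle, NaturalStyle, StructuralStyle and tangent_style are hypothetical names, and which types land on which side is exactly what is up for debate here:

```julia
using LinearAlgebra

abstract type TangentStyle end
struct NaturalStyle <: TangentStyle end
struct StructuralStyle <: TangentStyle end

# Hypothetical trait: default to structural, opt specific types into natural.
tangent_style(::Type) = StructuralStyle()
tangent_style(::Type{<:Number}) = NaturalStyle()
tangent_style(::Type{<:Array}) = NaturalStyle()

# A wrapper inherits the style of the storage it wraps.
tangent_style(::Type{<:Symmetric{T,S}}) where {T,S} = tangent_style(S)

println(tangent_style(ComplexF64))
println(tangent_style(typeof(Symmetric(rand(2, 2)))))
println(tangent_style(typeof((a = 1,))))
```

If every pullback consulted such a trait before returning dx, two paths through the same x could not disagree on representation, which is the tautology referred to above.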
Addressing your previous comment:
sufficiently weird array types not to have easy natural tangents
If we try to apply a generic rrule to a type which doesn't have a natural tangent defined, then I agree that you couldn't ever hope to use generic rrules, and that's fine as far as I'm concerned.
To my mind the more troublesome case is that in which an AbstractArray could make use of a generic rrule, but in our current setup is unable to because at no point in our rrules do we say "give me the natural tangent". As things stand, I think the problem will arise whenever we have the following structure:
foo(x::AbstractArray) = _hard_code_to_differentiate_so_has_generic_rrule
bar(x::Fill) = easy_to_differentiate_code_without_rrule
Forwards pass:
a = Fill(5.0, 2)
b = foo(a) # b isa Fill
c = bar(b)
Reverse-pass:
Db = pullback_bar(Dc) # Db isa Tangent
Da = pullback_foo(Db) # hits a generic rrule and breaks
Of course, you could replace Fill
with anything that's not a primitive.
A concrete example from the stdlib:
using LinearAlgebra
using Zygote
foo(x) = 5x
bar = parent
b, pullback_foo = Zygote.pullback(foo, Symmetric(randn(5, 5)))
c, pullback_bar = Zygote.pullback(bar, b)
Dc = randn(5, 5)
Db = pullback_bar(Dc)[1]
Db isa NamedTuple # structural tangent
pullback_foo(Db) # errors when it hits https://github.com/JuliaDiff/ChainRules.jl/blob/555ac11cf25c03e9dbb8dcd9d431bc8ff11349ee/src/rulesets/Base/arraymath.jl#L93
Uses Zygote
types, but the point is clear. You could replace foo
with any of the matrix functions whose rrule
s are defined here, ^, or inv, and get the same kind of problem.
This example also works with parent and Diagonal, when foo is:
muladd(::Diagonal, ::Diagonal, ::Diagonal) (slightly different error, so I'm assuming we have a generic rrule for something inside muladd)
-(::Diagonal)
-(::Diagonal, ::Diagonal)
real(::Diagonal)
imag(::Diagonal)
I tried the same thing for parent
and Adjoint
, and got an Adjoint
back for Db
. This means that someone has written a rule for this method. It turns out that it's in Zygote, so probably Mike or I wrote it at some point long ago. This rule will break if anyone ever wants the Adjoint
of a container-like AbstractArray
that only has a structural tangent. Even if this weren't the case, we just shouldn't need to write a rule for such a simple function.
Addressing your most recent comment:
The generic rule for sum(x) would be something like dx = similar(x).=1.
I see -- this seems like a reasonable heuristic. I guess we have to be careful about doing things like assuming mutability though e.g. in the StaticArray
s case.
But the reconstructor is the map in the other direction.
Indeed. I was just pointing out that the map is already written in that direction. My guess would be that writing it in the other direction given that code will be very straightforward. It also appears necessary to make scenario (1) (eg. the Symmetric
example above) work in general.
The simplest Symmetric's is easy but not necessarily cheap
Agreed.
am not sure whether this is possible at all for some of the weird wrapper combinations you mention. But I also don't know whether those can plausibly be produced by a generic rule as specified. If they cannot, then problem (1) does not occur.
Yeah -- it's not going to be possible for AbstractArray
s without a natural-structural mapping to participate in the generic rrules, and that seems acceptable to me.
It's definitely plausible that such an array could be output from an operation with a generic rrule. It would happen if the array's author implemented a generic operation (e.g. *(::Number, ::AbstractArray)
) in a non-differentiable manner (if they'd implemented it in a differentiable manner, they should @opt_out
), but which outputs their type.
I agree it seems hard to see how it could happen if the array author relies on fallback code somewhere.
I'm thinking again about my proposal from a few comments ago, and how the minimal number of translations between natural and structural tangents could be achieved.
Consider again *
:
function rrule(
    ::typeof(*),
    A::AbstractVecOrMat{<:CommutativeMulNumber},
    B::AbstractVecOrMat{<:CommutativeMulNumber},
)
    Y = A * B
    project_Y = ProjectToNatural(Y)
    project_A = ProjectTo(A)
    project_B = ProjectTo(B)
    function times_pullback(ȳ)
        Ȳ = project_Y(ȳ)
        dA = @thunk(project_A(Ȳ * B'))
        dB = @thunk(project_B(A' * Ȳ))
        return NoTangent(), dA, dB
    end
    return Y, times_pullback
end
Let's assume that we allow naturals and structurals to float around anywhere they like, so not the non-escaping situation I discussed.
Further, let's keep ProjectTo's semantics the same (yay, no breaking change).
Further assume that we have access to a function get_structural(tangent) which is the identity if the argument is a Tangent, and produces the structural if it's a natural tangent (I'm not sure such a function could be implemented precisely like this, but bear with me for now).
You could then easily define + between different tangent representations:
Base.:+(t::Tangent, s::Any) = t + get_structural(s)
Base.:+(t::Any, s::Tangent) = get_structural(t) + s
Natural tangents have to have + defined, so the natural + natural case isn't a problem.
In the above scheme, conversions between natural and structural would only happen if an rrule author explicitly requests the structural, or if a structural and natural interact, which seems like it ought to achieve the minimum possible number of conversions.
I think you'd just need to put a call to get_structural inside Zygote's / Diffractor's constructor rules (since constructors don't know how to handle naturals), and you'd be grand (everything would work).
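A minimal runnable sketch of this scheme for Diagonal only. ToyTangent is a hypothetical stand-in for Tangent so that the snippet needs nothing beyond LinearAlgebra, and get_structural is the assumed function from above, not an existing API:

```julia
using LinearAlgebra

# Hypothetical stand-in for a structural tangent of a primal of type P.
struct ToyTangent{P,T<:NamedTuple}
    fields::T
end
ToyTangent{P}(fields::NamedTuple) where {P} = ToyTangent{P,typeof(fields)}(fields)

# Identity on structurals; field-wise unpacking for a natural Diagonal.
get_structural(t::ToyTangent) = t
get_structural(d::Diagonal) = ToyTangent{typeof(d)}((diag = d.diag,))

# structural + structural: add field by field.
Base.:+(a::ToyTangent{P}, b::ToyTangent{P}) where {P} =
    ToyTangent{P}(map(+, a.fields, b.fields))

# Mixed cases convert the natural to a structural first.
Base.:+(t::ToyTangent{<:Diagonal}, s::Diagonal) = t + get_structural(s)
Base.:+(s::Diagonal, t::ToyTangent{<:Diagonal}) = get_structural(s) + t

D = Diagonal([1.0, 2.0])
t = ToyTangent{typeof(D)}((diag = [1.0, 1.0],))
total = t + Diagonal([10.0, 20.0])
println(total.fields.diag)
```

Here the natural happens to carry enough type information to build the structural by itself. That will not hold for every type, in which case the primal would have to be threaded through somehow.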
edit: my only concern with this scheme is the plausibility of implementing get_structural in the absence of access to the primal. I'm pretty sure that you'd need to know which primal you're dealing with in order to implement it, for the same reasons that you do in order to implement ProjectTo.
One other point: the fact that situation (1) occurs means that we need to write structural -> natural mappings anyway. I don't think there's any escaping it. At the moment we just have code that we can't AD 🤷
This is a lot of text. I think that all examples with Fill are solved by
(project::ProjectTo{Fill})(dx::Tangent) = Fill(dx.value / prod(length, project.axes), project.axes)
matching what already happens for Complex
julia> ProjectTo(1+im)(Tangent{Complex}(; re=33))
33.0 + 0.0im
plus having the gradient of getproperty apply projectors.
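The division by the length in that projector is what keeps the two representations consistent. A small sketch of the arithmetic, using a dense vector as a stand-in for Fill so that it runs without FillArrays (the numbers are made up for illustration):

```julia
n = 4
structural_value = 10.0                  # sensitivity w.r.t. the single `value` field
natural = fill(structural_value / n, n)  # dense stand-in for Fill(structural_value / n, n)

# Perturbing `value` by dv moves every entry of the array by dv.
dv = 0.5
via_structural = structural_value * dv
via_natural = sum(natural .* fill(dv, n))

println(via_structural ≈ via_natural)
```

Without the division, the natural would count the single degree of freedom n times.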
The example for (1) so far is
gradient(abs ∘ first ∘ parent ∘ log, Symmetric(rand(3,3))) # error, adjoint(::Tangent)
in which log
(from your list, less trivial than 5*) is the function with a generic rule, which propagates Symmetric
.
One objection here is that parent
is just another way of writing getproperty
. It's only safe to call if you know precisely the type already -- a weird mix of generic and specific. (And first
keeps an element we know not to be junk, sum
would not. Which means parent
is doing nothing.)
For this, it sounds like the projector for Symmetric should reconstruct from tangents. Maybe that's more expensive than ideal, but cheap compared to log
.
Zygote already does this for Adjoint, where it's free. But:
will break if anyone ever wants the Adjoint of a container-like AbstractArray that only has a structural tangent
Maybe. The real problem (1) is when this combination is also produced by the forward pass of a generic rule. That's what I think we still lack examples of.
For problem (2), I'm not sure that calling to_structural on the mismatched natural is always going to work. We allow (some) subspaces which aren't subtypes. If you want to add Tangent{UpperTriangular}(...) + Diagonal(...)
, turning the 2nd into Tangent{Diagonal}(...)
won't help. Although I don't have a good example that would produce this.
Maybe. The real problem (1) is when this combination is also produced by the forward pass of a generic rule. That's what I think we still lack examples of.
I stand by the above examples, and believe they are sufficient. To my mind it is enough that someone could do this, and there's no reasonable grounds on which we could object to someone doing this, in order for us to need to cover it. I'm happy to produce a MWE, but I don't believe there's a need to go hunting in the wild.
I agree that a line has to be drawn somewhere (e.g. we can agree that strings are non-differentiable, even though someone could in principle write a differentiable programme in which data passes through them), but I can find no such objection in terms of the existing examples.
One objection here is that parent is just another way of writing getproperty. It's only safe to call if you know precisely the type already -- a weird mix of generic and specific.
This example is analogous to the Fill example I gave before. Sometimes you want to specialise, other times you need to be generic, and they could be in the same programme.
This is beside the point though. The example in question is perfectly legal differentiable Julia code, and it doesn't involve mutation (because we have a rule for log). To my mind that is sufficient, but it's also part of the public API of a standard library. We have to cover this kind of thing by default, otherwise we get in the way of AD doing AD.
This is a lot of text. I think that all examples with Fill are solved by
I think this is correct, in this particular case, because Fill could plausibly be a primitive (I think. Again, given our lack of definitions and agreed-upon consistency criteria, I'm not going to say that I fully understand this case, and therefore I'm not 100% certain that this is a viable solution).
I wonder whether this is desirable though (making things primitives in general).
Suppose that the set of things that we ultimately decide that we need to do in order to support structural-to-natural mapping / projection solve all of the problems we're encountering, and work for all AbstractArrays. In that case, we could just do those things for Fills, SArrays etc, and not bother making them primitives. As far as I can see, we wouldn't lose anything by doing so, but we would gain simplicity -- we could provide a single clear set of instructions to array authors that want their types to play nicely with generic rrules, rather than giving them multiple options. This would also make our lives easier (fewer things to maintain / test / understand / define).
Phrased differently, making things primitives doesn't solve the wider problem. It only works in a subset of the cases that we care about, and we need to solve the more general problem anyway. If the more general solution isn't any more complicated than the specific solution, we only need the more general one.
(I realise this is back-tracking a bit on my earlier position regarding making SArray a primitive, but I can't think of any compelling advantages to doing so if we solve the problem more generally).
If you want to add Tangent{UpperTriangular}(...) + Diagonal(...), turning the 2nd into Tangent{Diagonal}(...) won't help.
That's an interesting example. I've only implemented a single to_structural method for each primal type, but I definitely believe that you could have multiple valid naturals. So multiple methods of to_structural are probably required for each primal. For example
to_structural(b::Bijection{<:UpperTriangular}, n::Diagonal) = ...
seems perfectly reasonable to me.
That being said, say you got a Diagonal natural for an UpperTriangular as you suggest. What should ProjectTo do to the Diagonal? Should it produce an UpperTriangular, or leave it as a Diagonal?
What should ProjectTo do
Right now it explicitly allows this, but I'm not sure much is gained:
https://github.com/JuliaDiff/ChainRulesCore.jl/blob/master/src/projection.jl#L330-L337
Without that, you'd get UpperTriangular(Diagonal(...)).
A sketch of a larger machine for such things: #390
Proposal to return Diagonal(Fill(...)) from tr: https://github.com/JuliaDiff/ChainRules.jl/issues/46#issuecomment-620707731 . The projector for a full matrix will certainly allow a Diagonal gradient through.
are sufficient. To my mind it is enough that someone could do this, [...] don't believe there's a need to go hunting in the wild.
Re (1), I'm dubious that we can design a good general system without quite a few examples which actually need it. Maybe they have other weird properties which neither of us have thought of, nature is more creative than us. Heard about the zoo-keeper who scaled up his elephant-cage design once he heard there are 50-ton mammals out there?
Re (1), I'm dubious that we can design a good general system without quite a few examples which actually need it.
I disagree with this for the following reason: our starting point is that we can do reverse-mode AD in any mutation-free Julia programme built from our primitives / things built from them (some Reals, Arrays, any struct built from these, and some basic operations on these (getfield, +, *, getindex etc)) via composition.
Consequently, whenever AD breaks we must have one of the following problems:
[...] + and * work that we somehow miss in our existing collection of primitives), or [...]
I'm confident that all of the problems we're discussing here are of the 3 or 4 variety. We're not discussing problems 1 (because all of the AbstractArrays we're talking about are structs built from things we understand) or 2. I accept it's possible that I'm missing a problem above, but I'm confident that if I am, it's not something that's affecting us here.
3 and 4 are things we've been struggling with since Zygote was first written, because we never properly sat down and figured out how to map between the data structures that we know how to do AD on (structural tangents), and representations of those data structures that are convenient when writing generic rules (natural tangents). A huge amount of progress has been made on problem 3 in the last few months by yourself, Lyndon, and Miha, in the form of the projection work. However, we still don't properly understand the relationship between structural and natural tangents, which means that we don't have a generic recipe to prevent problems 3 and 4 arising. They will continue to plague us until we develop this understanding, and do something with it.
If we develop a proper understanding of the relationship between structural and natural tangents, we should be able to
It must be possible to do this because we're not trying to write rules for programmes that we don't already know how to differentiate in principle -- we just have technical issues around mutation.
edit: I phrased the first sentence poorly. I don't disagree that we need examples, I just think we've already got enough to try and understand the nature of the problem at hand.
edit2: I explicitly suggested that it matters whether an rrule is implemented for the sake of performance or to hide mutations in 3 and 4. This was incorrect -- 3 and 4 persist regardless of the reason for implementing the rule.
A good example from a thing that came up with some people at Invenia today @mzgubic @thomasgudjonwright :
sqrt(Diagonal(x)) for some positive-valued vector x. If you take a look at the implementation of this operation in the standard library, you'll see that it's really easy to AD. It's a nice non-mutating implementation, just requiring getfield, Diagonal, and broadcast to AD properly. Indeed, my understanding is that it did this, and produced a Tangent (structural tangent), as expected.
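A minimal sketch of why this operation yields a structural tangent, assuming the method amounts to reading the diag field, broadcasting sqrt, and rewrapping. The names my_sqrt and my_sqrt_pullback are hypothetical, the pullback is derived by hand rather than produced by Zygote, and the NamedTuple stands in for the structural tangent.

```julia
using LinearAlgebra

# Hypothetical non-mutating definition: read the field, broadcast, rewrap.
my_sqrt(D::Diagonal) = Diagonal(sqrt.(D.diag))

# Hand-written pullback, composed step by step the way an AD tool would.
# `dy_diag` is the cotangent of the output's `diag` field.
function my_sqrt_pullback(D::Diagonal, dy_diag::AbstractVector)
    # Broadcasted sqrt: d/dx sqrt(x) = 1 / (2 * sqrt(x)).
    dx = dy_diag ./ (2 .* sqrt.(D.diag))
    # Pullback of reading the field `diag`: a tangent keyed by field name,
    # i.e. a structural tangent rather than another Diagonal.
    return (diag = dx,)
end

D = Diagonal([4.0, 9.0])
t = my_sqrt_pullback(D, [1.0, 1.0])
```

Nothing in this chain ever constructs a Diagonal for the tangent, which is why the result stays structural unless something projects it afterwards.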
My understanding is that Tom was writing some code that called the above, but also some other operations on Diagonal, whose pullbacks happened to return a Diagonal (natural tangent). These interacted later on in the code, and produced an error. The impact of this was the use of several hours of people's time, trying to figure out what was going on, and how to fix it. The result is that Miha will be opening a ChainRules PR in the near future.
Tom or Miha: please correct me if I've misrepresented any of the above.
For me, this is a clear example of the kind of problem we shouldn't be having. AD did its job perfectly in producing the result of sqrt(Diagonal(x)), but the structural tangent that was produced interacted poorly with other tangent representations, causing code to break.
Since the relationship between the structural and natural tangent of a Diagonal appears to be trivial, this is a really straightforward problem to patch (I believe @mzgubic is planning to make the operation return a natural tangent instead). But it's representative of a larger problem around the relationship between structural and natural tangents, and our (present) inability to gracefully handle interactions between the two, as we've discussed extensively in this issue.
Tom or Miha: please correct me if I've misrepresented any of the above.
Thanks for bringing this up, seems just about right!
One thing to add would be that if Zygote used ChainRules types internally we would not have this issue, since Tangent{Diagonal} and Diagonal can be added to each other. The problem comes when trying to add a NamedTuple to a Diagonal.
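The failure mode can be reproduced without any AD package. This is a sketch using only LinearAlgebra, in which a plain NamedTuple stands in for Zygote's structural representation and structural_to_natural is a hypothetical helper, not a function from Zygote or ChainRules.

```julia
using LinearAlgebra

# Natural tangent: another Diagonal, as a hand-written rule might return.
natural = Diagonal([1.0, 2.0, 3.0])

# Structural tangent as a NamedTuple keyed by field name.
structural = (diag = [4.0, 5.0, 6.0],)

# No `+` method exists for this pair, so accumulation throws a MethodError.
failed = try
    natural + structural
    false
catch err
    err isa MethodError
end

# Hypothetical helper: for Diagonal, the map from structural to natural
# does not depend on the primal, so it can be written directly.
structural_to_natural(t::NamedTuple{(:diag,)}) = Diagonal(t.diag)

total = natural + structural_to_natural(structural)
```

Diagonal is one of the easy cases; the conversion only needs the tangent itself.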
TIL. Could or should more of these be made to work?
julia> d = Diagonal([1,2,3]);
julia> t = Tangent{typeof(d)}(diag=[4,5,6]);
julia> d + t
3×3 Diagonal{Int64, Vector{Int64}}:
5 ⋅ ⋅
⋅ 7 ⋅
⋅ ⋅ 9
julia> t2 = Tangent{Any}(diag=[4,5,6]);
julia> d + t2 # could guess from d?
ERROR: ArgumentError: type does not have a definite number of fields
julia> u = UpperTriangular([1 2 3; 4 5 6; 7 8 9]);
julia> u + t # could re-construct Diagonal anyway and then try + on naturals?
ERROR: MethodError: no method matching +(::UpperTriangular{Int64, Matrix{Int64}}, ::Tangent{Diagonal ...
To explore these things, I made some examples here: https://github.com/mcabbott/OddArrays.jl
In particular, this has examples of the sort we didn't have above. For r = Rotation(θ) there is a method for r * r which returns the same type. This r doesn't really have a nice "natural" gradient representation. But the generic rule for * expects one. To make it work I opted out of the rule by hand, but I wonder if that can be automated -- probably no generic rule should accept a Tangent.
one option is overloading +(::Tangent{FooArray}, ::AbstractArray)
I think this happens too late. The map from the array to the tangent in general depends on the original array. The easy cases where it doesn't, like Diagonal, seem to be the cases where we can happily standardise on the "natural" form.
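To illustrate the dependence on the original array, here is a sketch with a hypothetical ConstVec type standing in for Fill, so that no packages are needed. The two conversion functions are assumptions for illustration only; the structural-to-natural direction mirrors the ProjectTo{Fill} method quoted earlier in the thread, which divides by the number of elements taken from the primal's axes.

```julia
# Hypothetical stand-in for FillArrays.Fill: one value repeated `len` times.
struct ConstVec{T} <: AbstractVector{T}
    value::T
    len::Int
end
Base.size(v::ConstVec) = (v.len,)
function Base.getindex(v::ConstVec, i::Int)
    @boundscheck checkbounds(v, i)
    return v.value
end

# Natural -> structural: every element aliases `value`, so the
# element-wise sensitivities are summed. No primal is needed.
natural_to_structural(dx::AbstractVector) = (value = sum(dx), len = nothing)

# Structural -> natural: the tangent does not record the length,
# so it has to be read from the primal.
structural_to_natural(primal::ConstVec, t::NamedTuple) =
    ConstVec(t.value / primal.len, primal.len)

x = ConstVec(2.0, 4)
t = natural_to_structural(ones(4))
n = structural_to_natural(x, t)
```

A method of + that receives only the two tangents has no access to x, which is the sense in which overloading + comes too late for types like this.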
Natural tangents are helpful because they're sometimes more intelligible to humans than structural tangents, and can play nicely with generic linear operations written in rrules, making it more straightforward to write generic rrules. However, this is not always the case, and they can cause complications.
The purpose of this issue is to document situations in which structural tangents are either preferable or a necessity, and the complications that can occur when structural and natural tangents interact. Some of it is obvious, and I'm intentionally not proposing to do anything, I just needed to get this stuff out of my head so that I can focus on other work for the next couple of weeks without this bothering me.
Maybe we should link this in the docs that @oxinabox recently merged?
In the following, please assume that the primal is an AbstractArray, the structural tangent a Tangent, and the natural tangent another AbstractArray.
Interactions between structurals and naturals
Suppose we have code in which a hand-written pullback produces a natural tangent for a primal, and an automatically generated pullback produces a structural tangent for the same primal (say, because it hit getfield). At some point, these will need to be added together. This means that code must exist to convert one to the other, and it must be hit. To my knowledge, we don't really deal with this at the minute.
Sometimes structurals are more natural
Clarity is in the eye of the beholder, but to my eye it's not uncommon for the structural tangent to be much more straightforwardly interpretable than the natural tangent. Moreover, it's always obvious what the structural tangent is, whereas one has to think (sometimes for an extended period of time, and consult with others) to determine an appropriate natural tangent. This problem is compounded by our current lack of a rigorous definition for the natural tangent.
One example of this is Fill arrays. If we think about a Fill as a struct in the context of AD, it's incredibly clear what's going on and what its structural tangent represents. Conversely, if we think about it as an array, its natural tangent requires a good deal of thought. Again, the lack of a proper definition of a natural tangent is possibly the culprit.
Another good example is the WoodburyPDMat. I haven't the foggiest idea what an appropriate natural tangent would be, nor do I particularly want to. The structural doesn't suffer this problem. In this sense, it's much more natural to think about the structural tangent. Later in this issue, I'm going to use this as an example of a situation in which using the structural tangent is necessary, and the natural is undefined. If anyone wants to argue that I must derive a natural tangent for this matrix, then we will need to have a separate conversation.
Sometimes the natural is simply redundant
On the WoodburyPDMat example above, we intentionally only implement a few high-level linear algebra operations. There's simply never any need for a natural tangent because all of the code has been written in an AD-friendly manner (non-mutating, small number of differentiable high-level operations).
Other examples are ColVecs and RowVecs in the JuliaGPs ecosystem. They're thin wrappers around an AbstractMatrix which are really only designed to make its interpretation in a particular context clear. While getindex is defined, it's considered a bug if it's hit inside AD. Instead, the use of getfield is central to ColVecs and RowVecs usage. Consequently, AD ought always to be able to derive pullbacks automatically in practice -- certainly we don't want to have to write any rules, or define ProjectTo.
Structural tangents are sometimes a necessity
Symmetric matrices are a good example of something that can use a natural differential a decent chunk of the time, but sometimes simply cannot.
If what a Symmetric wraps is a primitive, then we'll often want to use the natural differential. For example, this rule for svd. This is fine.
In other situations, it's easier to think about rules producing structural tangents. For example, Iain Murray's derivation of the rrule for the Cholesky factorisation uses only the upper triangle of a Symmetric, so it's easier to think in terms of the structural when writing that rrule.
A more extreme example is if one were to put a WoodburyPDMat inside a Symmetric and do AD. The tangent of a WoodburyPDMat will generally be a Tangent, which cannot itself be stored inside of a Symmetric. Therefore, the tangent of a Symmetric{<:Number, <:WoodburyPDMat} must be a Tangent.
A condition under which a given AbstractArray can be treated as a primitive
By calling a type T primitive, I mean that the tangents of T are always of type T themselves.
Assuming that we define the pullback of getfield on a given struct to return a Tangent, the above yields the following condition for the possibility of treating a particular AbstractArray (which is a struct) as a primitive:
Each field must be able to take the type of any of its tangents.
Interestingly, this seems to preclude treating a number of very common AbstractArrays as primitives:
Symmetric (wraps an AbstractMatrix, whose tangent may be a Tangent)
Diagonal (for the same reason)
Transpose / Adjoint (for the same reason)
SArray (wraps a Tuple, whose tangent is always a Tangent)
This is, of course, not to say that natural tangents ought not to be used for these types, it is simply to say that it's not possible to preclude the need for the use of a structural tangent some of the time.
Wait, we can't treat SArray as a primitive?
The above claim was made under the assumption that we make the pullback for the rrule for getfield return a Tangent. This is crucial. As @oxinabox pointed out the other day, we don't always have to define getfield this way.
If for an SArray we make said pullback return another SArray, the problem disappears. Indeed, in this case, we can safely treat SArrays as primitives and never have to worry about the structural derivative. For example, the rrule for getfield might be something like
We could also ensure that all other operations on SArrays return SArrays.
It's not obvious that the same can be done for Symmetric / Diagonal / Transpose / Adjoint etc. The thing that made it work for SArray was the ability to meaningfully convert the structural tangent of its field (a Tangent) into the type of its field (a Tuple).
Ultimately, this seems to come down to how we define getfield's rrule. If we are willing to put the work in to ensure that any tangent for a field can be converted into a thing that can be put back into the primitive (and still have sensible semantics, like + and Real * working) then we might be able to treat a given type as primitive. This simply isn't possible for lots of arrays because, as discussed above, we can't generally rule out the need for a structural tangent, which rules out Symmetric etc.
This doesn't seem to have tied down exactly what we can / can't treat as a primitive, but I think it gets us closer, and it at least suggests that we can ensure the non-existence of structural tangents for SArrays if someone is willing to put the work in (and figure out where the code should live!)
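The idea of a field-access pullback that hands back the array type itself can be sketched as follows. TupleVec is a hypothetical stand-in for SArray and getdata_with_pullback is an illustrative name, not the rrule from the original post or any ChainRules API. The sketch also assumes the cotangent of the tuple field has already been converted to a plain tuple.

```julia
# Hypothetical stand-in for SArray: a vector backed by a tuple field.
struct TupleVec{N,T} <: AbstractVector{T}
    data::NTuple{N,T}
end
Base.size(::TupleVec{N}) where {N} = (N,)
Base.getindex(v::TupleVec, i::Int) = v.data[i]

# Read the `data` field and return a pullback that rewraps the field's
# cotangent in the primal's own type, instead of a field-keyed tangent.
function getdata_with_pullback(v::TupleVec)
    pullback(ddata::NTuple) = TupleVec(ddata)
    return v.data, pullback
end

v = TupleVec((1.0, 2.0))
data, back = getdata_with_pullback(v)
dv = back((0.5, 0.5))
```

Because dv is itself an AbstractVector, it adds to any other array-valued tangent with the ordinary + for arrays, so no structural tangent for TupleVec ever appears.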