JuliaDiff / ChainRulesCore.jl

AD-backend agnostic system defining custom forward and reverse mode rules. This is the lightweight core to allow you to define rules for your functions in your packages, without depending on any particular AD system.

Structural tangents are cool too, and SArrays can probably be primitives #441

Open willtebbutt opened 3 years ago

willtebbutt commented 3 years ago

Natural tangents are helpful because, in a number of situations, they're more intelligible to humans than structural tangents, and they play nicely with generic linear operations written in rrules, which makes it more straightforward to write generic rrules. However, this is not always the case, and they can cause complications.

The purpose of this issue is to document situations in which structural tangents are either preferable or a necessity, and the complications that can occur when structural and natural tangents interact. Some of it is obvious, and I'm intentionally not proposing to do anything; I just needed to get this stuff out of my head so that I can focus on other work for the next couple of weeks without this bothering me.

Maybe we should link this in the docs that @oxinabox recently merged?

In the following, please assume that the primal is an AbstractArray, the structural tangent a Tangent, and the natural tangent another AbstractArray.

Interactions between structurals and naturals

Suppose we have code in which a hand-written pullback produces a natural tangent for a primal, and an automatically generated pullback produces a structural tangent for the same primal (say, because it hit getfield). At some point, these will need to be added together. This means that code must exist to convert one to the other, and it must be hit. To my knowledge, we don't really deal with this at the minute.
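Here is a minimal sketch of the conversion code that has to exist. It uses LinearAlgebra's Transpose, whose only field is `parent`, and plain NamedTuples as a stand-in for ChainRulesCore's Tangent. The helper `structuralise` is hypothetical. It maps the natural tangent (an array shaped like the primal) onto the structural one, so the two can be accumulated.

```julia
using LinearAlgebra

X = Transpose([1.0 2.0 3.0; 4.0 5.0 6.0])   # a 3×2 lazy transpose; its single field is `parent`

# Natural tangent, e.g. from a hand-written pullback of sum(X): an array shaped like X.
natural = ones(size(X))

# Structural tangent, e.g. from an AD-derived pullback that hit getfield(X, :parent).
structural = (parent = [10.0 20.0 30.0; 40.0 50.0 60.0],)

# `structural + natural` would throw a MethodError: there is no +(::NamedTuple, ::Matrix).

# Hypothetical conversion: pull the natural tangent back through X = transpose(parent).
structuralise(::Transpose, dX::AbstractMatrix) = (parent = permutedims(dX),)

# Accumulate in the structural representation, field by field.
accumulated = (parent = structural.parent + structuralise(X, natural).parent,)
```

Converting in the other direction (structural to natural) would be an equally valid design here. Which direction to pick is the open question in this issue.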

Sometimes structurals are more natural

Clarity is in the eye of the beholder, but to my eye it's not uncommon for the structural tangent to be much more straightforwardly interpretable than the natural tangent. Moreover, it's always obvious what the structural tangent is, whereas one has to think (sometimes for an extended period of time, and consult with others) to determine an appropriate natural tangent. This problem is compounded by our current lack of a rigorous definition for the natural tangent.

One example of this is Fill arrays. If we think about a Fill as a struct in the context of AD, it's incredibly clear what's going on and what its structural tangent represents. Conversely, if we think about it as an array, its natural tangent requires a good deal of thought. Again, the lack of a proper definition of a natural tangent is possibly the culprit.
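The sketch below makes the two views concrete. It uses a hypothetical minimal stand-in for FillArrays' Fill (vector case only). `MyFill` and both conversion helpers are assumptions, not FillArrays or ChainRulesCore API.

Because every entry of the array is `value`, a dense natural tangent converts to the structural one by summing. Going the other way means choosing a natural representative. Spreading the structural tangent evenly over the entries is one possible choice.

```julia
# Hypothetical minimal stand-in for FillArrays.Fill (vector case only).
struct MyFill{T} <: AbstractVector{T}
    value::T
    len::Int
end
Base.size(x::MyFill) = (x.len,)
function Base.getindex(x::MyFill, i::Int)
    @boundscheck checkbounds(x, i)
    return x.value
end

# Natural -> structural: every entry is `value`, so the chain rule sums the
# contributions of all entries. `len` is an integer, so it gets no tangent (`nothing`).
natural_to_structural(dx::AbstractVector) = (value = sum(dx), len = nothing)

# Structural -> natural: one possible choice is the constant array whose Euclidean
# inner product with any constant perturbation fill(δ, n) equals t.value * δ.
structural_to_natural(t::NamedTuple, n::Int) = MyFill(t.value / n, n)

x = MyFill(5.0, 4)
s = natural_to_structural([1.0, 2.0, 3.0, 4.0])   # (value = 10.0, len = nothing)
n = structural_to_natural(s, length(x))           # a MyFill whose entries are all 2.5
```

Mapping a dense natural tangent to this representative is a projection, not an inverse: `[1.0, 2.0, 3.0, 4.0]` comes back as four copies of `2.5`. That is part of why the natural tangent for a Fill takes some thought.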

Another good example is the WoodburyPDMat. I haven't the foggiest idea what an appropriate natural tangent would be, nor do I particularly want to. The structural doesn't suffer this problem. In this sense, it's much more natural to think about the structural tangent. Later in this issue, I'm going to use this as an example of a situation in which using the structural tangent is necessary, and the natural is undefined. If anyone wants to argue that I must derive a natural tangent for this matrix, then we will need to have a separate conversation.

Sometimes the natural is simply redundant

In the WoodburyPDMat example above, we intentionally implement only a few high-level linear algebra operations. There's simply never any need for a natural tangent because all of the code has been written in an AD-friendly manner (non-mutating, small number of differentiable high-level operations).

Other examples are ColVecs and RowVecs in the JuliaGPs ecosystem. They're thin wrappers around an AbstractMatrix which are really only designed to make its interpretation in a particular context clear. While getindex is defined, it's considered a bug if it's hit inside AD. Instead, the use of getfield is central to ColVecs and RowVecs usage. Consequently, AD ought always to be able to derive pullbacks automatically in practice -- certainly we don't want to have to write any rules, or define ProjectTo.

Structural tangents are sometimes a necessity

Symmetric matrices are a good example of something that can use a natural differential a decent chunk of the time, but sometimes simply cannot.

If the matrix that a Symmetric wraps is a primitive, then we'll often want to use the natural differential, as in this rule for svd. This is fine.

In other situations, it's easier to think about rules producing structural tangents. For example, Iain Murray's derivation of the rrule for the Cholesky factorisation uses only the upper triangle of a Symmetric, so it's easier to think in terms of the structural when writing that rrule.
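The following illustration uses only the LinearAlgebra stdlib and shows why thinking structurally suits such rules. `Symmetric(A, :U)` reads nothing but the upper triangle of its `data` field. Perturbing an entry below the diagonal therefore leaves the primal unchanged, so any structural tangent for `data` is zero there. The function `f` is an arbitrary example.

```julia
using LinearAlgebra

A = [1.0 2.0; 99.0 3.0]   # 99.0 sits in the lower triangle, which :U ignores
S = Symmetric(A, :U)

S[2, 1]        # 2.0, mirrored from the upper triangle
S.data[2, 1]   # 99.0, the raw field entry, meaningless as part of S

# Any scalar function of S will do; this one sums the squares of all entries of S.
f(S) = sum(abs2, S)

A_low = copy(A); A_low[2, 1] += 1.0   # perturb an ignored entry of `data`
A_up  = copy(A); A_up[1, 2]  += 1.0   # perturb its mirror in the upper triangle

f(Symmetric(A_low, :U)) == f(S)   # true: the structural tangent is zero below the diagonal
f(Symmetric(A_up, :U))  == f(S)   # false
```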

A more extreme example is if one were to put a WoodburyPDMat inside a Symmetric and do AD. The tangent of a WoodburyPDMat will generally be a Tangent, which cannot itself be stored inside of a Symmetric. Therefore, the tangent of a Symmetric{<:Number, <:WoodburyPDMat} must be a Tangent.

A condition under which a given AbstractArray can be treated as a primitive

By calling a type T primitive, I mean that the tangents of T are always of type T themselves.

Assuming that we define the pullback of getfield on a given struct to return a Tangent, the above yields the following condition for the possibility of treating a particular AbstractArray (which is a struct) as a primitive:

Each field must be able to take the type of any of its tangents.
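The condition can be spot-checked for one particular tangent at a time. The checker `can_store_tangent` below is hypothetical, and a NamedTuple stands in for Tangent. The condition in the text quantifies over all tangents, so a single `true` is only evidence, while a single `false` is enough to rule a type out.

```julia
using LinearAlgebra

# Hypothetical check: can each field of this tangent be stored in the corresponding
# field of the primal type T?
can_store_tangent(::Type{T}, tangent::NamedTuple) where {T} =
    all(name -> getfield(tangent, name) isa fieldtype(T, name), fieldnames(T))

# Complex: Real fields, Real tangents, so this tangent fits.
can_store_tangent(Complex{Float64}, (re = 1.0, im = 2.0))                  # true

# Diagonal wrapping a Vector: a Vector tangent fits in the `diag` field.
can_store_tangent(Diagonal{Float64,Vector{Float64}}, (diag = [1.0, 2.0],))  # true

# Diagonal wrapping a range, whose tangent here is structural (start/stop fields):
can_store_tangent(typeof(Diagonal(1:3)), (diag = (start = 1.0, stop = 2.0),))  # false
```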

Interestingly, this seems to preclude treating a number of very common AbstractArrays as primitives:

This is, of course, not to say that natural tangents ought not to be used for these types, it is simply to say that it's not possible to preclude the need for the use of a structural tangent some of the time.

Wait, we can't treat SArray as a primitive?

The above claim was made under the assumption that we make the pullback for the rrule for getfield return a Tangent. This is crucial. As @oxinabox pointed out the other day, we don't always have to define getfield this way.

If for an SArray we make said pullback return another SArray, the problem disappears. Indeed, in this case, we can safely treat SArrays as primitives and never have to worry about the structural derivative. For example, the rrule for getfield might be something like

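function rrule(::typeof(getfield), S::SArray, ::Any) # there's only one field, so nothing else to access
    # S.data is a Tuple, so its tangent must be a Tangent. Splat it back into an SArray of the
    # same type, and return NoTangent() for getfield itself and for the field name.
    getfield_pullback(TY::Tangent) = (NoTangent(), typeof(S)(TY...), NoTangent())
    return S.data, getfield_pullback
end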

We could also ensure that all other operations on SArrays return SArrays.

It's not obvious that the same can be done for Symmetric / Diagonal / Transpose / Adjoint etc. The thing that made it work for SArray was the ability to meaningfully convert the structural tangent of its field (a Tangent) into the type of its field (a Tuple).
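Here is a pure-Julia sketch of the SArray trick. It uses a hypothetical `SVec` type as a stand-in for StaticArrays' SVector, and a plain NTuple as a stand-in for the Tangent of its `data` field. The point is that the field's tangent converts straight back into the field's own type. The pullback of getfield can therefore hand back an `SVec`, and accumulation stays within `SVec`.

```julia
# Hypothetical stand-in for StaticArrays.SVector: a vector backed by an NTuple.
struct SVec{N,T} <: AbstractVector{T}
    data::NTuple{N,T}
end
Base.size(::SVec{N}) where {N} = (N,)
Base.getindex(x::SVec, i::Int) = x.data[i]

# Tangents stay SVecs under addition, so accumulating them never needs a Tangent.
Base.:+(a::SVec{N,T}, b::SVec{N,T}) where {N,T} = SVec{N,T}(a.data .+ b.data)

# Pullback for getfield(x, :data): the cotangent of the NTuple field is taken to be an
# NTuple of the same type, which is exactly what an SVec stores, so wrap it directly.
getfield_data_pullback(::SVec{N,T}, ddata::NTuple{N,T}) where {N,T} = SVec{N,T}(ddata)

x = SVec{2,Float64}((1.0, 2.0))
dx1 = getfield_data_pullback(x, (1.0, 1.0))   # from a path that used x.data
dx2 = SVec{2,Float64}((0.5, 0.5))             # from a path that used x as an array
dx1 + dx2                                     # an SVec{2,Float64} with entries 1.5, 1.5
```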

Ultimately, this seems to come down to how we define getfield's rrule. If we are willing to put the work in to ensure that any tangent for a field can be converted into a thing that can be put back into the primitive (and still have sensible semantics, like + and Real * working) then we might be able to treat a given type as primitive. This simply isn't possible for lots of arrays because, as discussed above, we can't generally rule out the need for a structural tangent, which rules out Symmetric etc.

This doesn't seem to have tied down exactly what we can / can't treat as a primitive, but I think it gets us closer, and it at least suggests that we can ensure the non-existence of structural tangents for SArrays if someone is willing to put the work in (and figure out where the code should live!)

mcabbott commented 3 years ago

It's pretty easy to make Zygote give you errors from natural + structural, for instance:

julia> gradient(x -> (real(x) + x.im), 1+im)
ERROR: MethodError: no method matching +(::NamedTuple{(:re, :im), Tuple{Nothing, Int64}}, ::Int64)

julia> gradient(x -> (sum(x) + last(x)), 1:3)
ERROR: MethodError: no method matching +(::NamedTuple{(:start, :stop), Tuple{Nothing, Int64}}, ::FillArrays.Fill{Int64, 1, Tuple{Base.OneTo{Int64}}})

These don't seem to be very common in the wild? In these two cases, I think the rule for getproperty should call a ProjectTo which standardises on the natural. But in general, it seems tricky.

willtebbutt commented 3 years ago

These don't seem to be very common in the wild?

Not super common, but I don't think I've worked on a project where I've not encountered them (for example, if you do enough things with a Diagonal, you will probably encounter one at some point), and I usually wound up type-pirating to get around them.

In these two cases, I think the rule for getproperty should call a ProjectTo which standardises on the natural.

Again, not completely sure that I agree re. Fills and naturals for the reasons laid out above (the natural for a Fill is deeply unintuitive to me, whereas the structural is simple). For example, could / should @no_opt most things in FillArrays, including sum, and just get the structural.

edit: I meant @opt_out, not @no_opt.

oxinabox commented 3 years ago

It's pretty easy to make Zygote give you errors from natural + structural

For these cases one option is overloading +(::Tangent{FooArray}, ::AbstractArray), which could then do either a projection onto the natural, or a structuralization onto the Tangent. Or potentially even a form of thunking that just delays until we go to accumulate against a primal (Keno was talking about that a while ago).
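Below is a minimal sketch of the "projection onto the natural" variant of such an overload, applied to the Complex example from the previous comment. `StructuralTangent` is a hypothetical stand-in for ChainRulesCore's Tangent, so that the sketch has no package dependencies. `to_natural` is likewise hypothetical. Fields that are `nothing` are treated as zero.

```julia
# Hypothetical stand-in for Tangent{P}: a NamedTuple tagged with the primal type P.
struct StructuralTangent{P,B<:NamedTuple}
    backing::B
end
StructuralTangent{P}(backing::B) where {P,B<:NamedTuple} = StructuralTangent{P,B}(backing)

# Projection onto the natural tangent: a structural tangent for a Complex becomes a Complex.
to_natural(t::StructuralTangent{<:Complex}) =
    complex(something(t.backing.re, 0), something(t.backing.im, 0))

# The overloads: adding a natural tangent (any Number) to a structural one for a Complex.
Base.:+(t::StructuralTangent{<:Complex}, x::Number) = to_natural(t) + x
Base.:+(x::Number, t::StructuralTangent{<:Complex}) = x + to_natural(t)

# The two tangents from x -> real(x) + x.im at x = 1 + im:
from_getfield = StructuralTangent{Complex{Int}}((re = nothing, im = 1))
from_real = 1
from_getfield + from_real   # 1 + 1im
```

A structuralizing overload would instead map the Number into a `(re = ..., im = ...)` tangent. Both choices fix `+`, but they disagree about which representation flows onward through the rest of the reverse pass.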

Possibly this should be a separate issue

willtebbutt commented 3 years ago

For these cases one option is overloading +(::Tangent{FooArray}, ::AbstractArray)

Certainly -- this kind of code has to be written if we're ever allowing naturals and structurals to interact. I think there's a whole separate issue to be had about the points in the code at which we want the structural and at which points we want the natural.

For example, I think it's reasonable to assert that we often want the natural inside rules, as evidenced by people's experience over the last few months (not always though). Conversely, constructors like structurals, so it seems possible that we want to ensure that naturals don't escape rules, or something like that.

Possibly this should be a separate issue

Yes.

mcabbott commented 3 years ago

In these two cases,

not completely sure that I agree re. Fills and naturals for the reasons laid out above (the natural for a Fill

Then you're looking at different examples, or talking about generalities. My specific two examples do not contain any tangent for an input x::Fill, natural or otherwise.

But if you don't find the natural for x::Fill intuitive, the natural gradient for a range will be worse.

However, the natural gradient for x::Complex is not so bad. So there are some structs for which we always want to normalise to the natural tangent. Notice also that in my example, it does not hit +(::Tangent{Foo}, ::Foo) -- the natural tangent from the other path lies in a subspace which isn't a subtype.

I think that's clearly what we want for SArrays too. Handling a Tangent{SVector}(; data=Tuple(...)) seems strictly inferior. In fact accessing its fields directly (in the forward code) seems extremely odd here, these are arrays. (Although obviously people do odd things in the wild & we should try to build guardrails, where possible.)

willtebbutt commented 3 years ago

My specific two examples do not contain any tangent for an input x::Fill, natural or otherwise.

Wow, sorry, I completely misread your example. You're right that it's painful working with these types. This is why I prefer the structural tangent for them, although I also agree that it's not always what you need.

I think that's clearly what we want for SArrays too. Handling a Tangent{SVector}(; data=Tuple(...)) seems strictly inferior.

I agree in the specific case of SArrays that we probably want them to be primitive.

In fact accessing its fields directly (in the forward code) seems extremely odd here, these are arrays. (Although obviously people do odd things in the wild & we should try to build guardrails, where possible.)

This I disagree with, in the general case. Array authors access fields all the time when writing methods optimised for their arrays, Fill arrays being a good example (getindex is essentially just getfield here). It's simply not odd to want to access fields of a structured array.

mcabbott commented 3 years ago

Again, that paragraph you are quoting from is specifically about SArrays, not generalities. I'm not sure I have ever seen anyone access the field by name; I had no idea what its name was. They are used as arrays in user code, and their interface with the outside world is through the usual AbstractArray interface. What the package authors do behind closed doors is, of course, nobody else's business.

I agree that what to do with other types is more tricky. Types that straddle the border between being containers and being arrays, I guess. Although I suppose I'd still like better examples of how these get used, to ground discussion. For instance field access is not really part of the API of Diagonal, and even less so for Symmetric (since it gets you meaningless data) but perhaps it is for some other weird type.

willtebbutt commented 3 years ago

Again, that paragraph you are quoting from is specifically about SArrays, not generalities. I'm not sure I have ever seen anyone access the field by name; I had no idea what its name was. They are used as arrays in user code, and their interface with the outside world is through the usual AbstractArray interface. What the package authors do behind closed doors is, of course, nobody else's business.

Cool. Apologies for generalising! Yeah, we're definitely in agreement that users probably shouldn't be accessing SArrays via getfield, that just seems weird. That being said, if we implement getfield properly, it should just be fine 🤷 (if ill-advised).

Although I suppose I'd still like better examples of how these get used, to ground discussion.

Well the ColVecs / RowVecs / Fill / WoodburyPDMat are all examples intended for this purpose.

ColVecs / RowVecs / WoodburyPDMat are of the "don't ever give me a natural tangent" flavour -- they're entirely written with a small set of AD-friendly operations in mind, so we have no need to write rrules for them, except perhaps to @non_differentiable some things here and there, and @opt_out of any generic rules that we don't like (although I don't think we've encountered any yet fortunately -- might change as people start writing more generically-typed rules though).

Fill arrays are a bit different -- pretty much every function they provide specialised implementations for should produce a structural tangent just fine without writing any rrules (e.g. sum, broadcast, map etc), but there are probably others that don't have specialised implementations and hit generic fallbacks, and therefore need the fallback rules + projection mechanism. So Fill is a good example where we should expect a mixture of different tangent types to appear during the reverse-pass of AD.

Are these examples missing something that you're interested in seeing?

For instance field access is not really part of the API of Diagonal, and even less so for Symmetric (since it gets you meaningless data) but perhaps it is for some other weird type.

I agree that field access is typically not part of the user-facing API of an array (although I also agree that it might be in some weird instances). However, AD ought to just work on functions written by library authors if they write code with AD in mind (library authors shouldn't have to write rrules if they write their code in an AD-friendly manner).

That being said, if they also want their types to work with the generic fallbacks defined for AbstractArrays, they'll need to write ProjectTo, and functionality to convert natural tangents into structural tangents so that one can AD through the constructors (unless they make ProjectTo return a structural, in which case the natural doesn't exist. Not sure what kinds of problems that can cause though -- I imagine it precludes the use of some generic rrules that they might be interested in. As you say though, examples of this would help the discussion).

As regards people accessing internal fields of types when they probably shouldn't be, I agree that they probably shouldn't be. However, we do know how to AD through ill-advised field access, and we know how to represent the tangent. To my mind, from an AD's perspective, there's not really a distinction between approved field access by type authors, and ill-advised field access by misguided users.

Taking the Diagonal example, if someone writes

foo(D::Diagonal) = sum(D.diag)

we can AD it with no problem provided that the rule system doesn't somehow get in the way. If the standard library changes the internals of Diagonal and breaks a user's code, then of course the user is entirely to blame.

mcabbott commented 3 years ago

examples of how these get used

Are these examples missing something

Yes, thanks for the types, but what I meant was whether there are closer to end-to-end examples which give errors / wrong results / are inefficient. If ColVecs / RowVecs are glorified NamedTuples, which are always treated as structs, always Tangent, then there is no problem right?

The tricky cases seem to be (1) where some generic rule passes a special type along, and hence its generic pullback probably expects an AbstractArray, and (2) where the same object is used twice, and one path treats it as an array, the other as a struct, hence + fails. My explicit examples were of (2).

The simplest case I can picture for (1) is this, but I don't know why it isn't an error:

julia> gradient((x,y) -> (x*y).re, 1+im, 2)
((re = 2, im = nothing), (re = 1 - 1im, im = nothing))

Of course, this one should be avoided in multiple ways: You should write real() not .re to be safe. The getfield of Complex should not construct a Tangent/NamedTuple, since Complex is clearly simple enough to always be its own gradient; Maybe that needs a special overload as you suggest for SArray, or maybe the rule for getproperty just calls ProjectTo which handles it.
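A finite-difference check in plain Julia, with no AD involved, shows what the correct tangents are here. The tangent for x should be the natural Complex 2 + 0im, or equivalently the structural (re = 2, im = 0). The tangent for y should be the real number 1. The entry Zygote returned for y above is therefore not even a valid tangent for a real argument.

```julia
# Central finite differences for f(x, y) = real(x * y) at x = 1 + im, y = 2.
f(x, y) = real(x * y)
x, y, h = 1.0 + 1.0im, 2.0, 1e-6

d_re = (f(x + h, y) - f(x - h, y)) / (2h)            # ≈ 2.0
d_im = (f(x + h * im, y) - f(x - h * im, y)) / (2h)  # 0.0: the imaginary part of x drops out
d_y  = (f(x, y + h) - f(x, y - h)) / (2h)            # ≈ 1.0

natural_dx = complex(d_re, d_im)   # ≈ 2.0 + 0.0im, the natural tangent for x
```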

But there might be others which are less easy to avoid -- what are the simplest such? For Fill I can invent things if I use field access. But, the entire point of Fill is that it's a cheap thing to pass into generic array code which does not know that it's this type... if you knew this all the way through, you would use a number... so this seems like cheating:

julia> gradient(x -> sum(x .+ 10), Fill(3,3))
(3-element Fill{Int64}: entries equal to 1,)

julia> gradient(x -> (x .+ 10).value, Fill(3,3))  # generic unbroadcast does not handle Tangent/NamedTuple
ERROR: MethodError: no method matching size(::NamedTuple{(:value, :axes), Tuple{Int64, Nothing}})
Stacktrace:
  [1] unbroadcast(x::Fill{Int64, 1, Tuple{Base.OneTo{Int64}}}, x̄::NamedTuple{(:value, :axes), Tuple{Int64, Nothing}})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/lib/broadcast.jl:51
willtebbutt commented 3 years ago

If ColVecs / RowVecs are glorified NamedTuples, which are always treated as structs, always Tangent, then there is no problem right?

You're correct. Again, lack of clarity on my part -- with these kinds of things I was just trying to assemble some concrete examples of things that do indeed work just fine, so that I can point to them if anyone is ever confused about the need for structural tangents for AbstractArrays.

The tricky cases seem to be (1) where some generic rule passes a special type along, and hence its generic pullback probably expects an AbstractArray, and (2) where the same object is used twice, and one path treats it as an array, the other as a struct, hence + fails. My explicit examples were of (2).

Agreed. The options for resolving case 2 are clearer to me than 1. Either, you must (I think)

I quite like the second approach because, unless I'm mistaken, it would make explicit what I think we're currently doing implicitly in a bunch of rrules for constructors (conversion from natural to structural), for example Diagonal.

I'm struggling to come up with good examples for (1). Any thoughts?

But, the entire point of Fill is that it's a cheap thing to pass into generic array code which does not know that it's this type... if you knew this all the way through, you would use a number...

This is not always the case. My favourite example of this is in AbstractGPs. The third field of our FiniteGP type is always a covariance matrix (I'm not sure why we don't make it subtype AbstractMatrix now that I read the code again). Very often, it will be either

Diagonal(some_Vector_of_Reals)
Diagonal(Fill(s, N))

for some real number s and integer N, and some_Vector_of_Reals<:AbstractVector{<:Real}. A thing we rely on is specialising based on whether this matrix is a Diagonal or a Diagonal{<:Real, <:Fill}, because there are often entirely different algorithms that we can use in those cases. TemporalGPs relies on this entirely (because efficient inference isn't possible if this matrix isn't diagonal), and OILMMs.jl (which is currently being reworked into a more general package) only works when this is a Diagonal{<:Real, <:Fill}.

We need it to be an AbstractMatrix{<:Real} so that we can be sure it can be added to another matrix of the same size (wouldn't work for scalars or vectors). Generally speaking, if we hit this code path we couldn't care less what kind of matrix it is.

So I think the way in which this differs from what you're presuming is that some of the time we couldn't care less whether we have a Fill or not, but other bits of the time we really do.

But there might be others which are less easy to avoid -- what are the simplest such?

I think the Fill example is a good one. It can be made to feel a bit more realistic by opting out of sum by defining my_sum (if I'm not mistaken Zygote defines a rule for sum). Something like

julia> using FillArrays

julia> using Zygote

julia> using FillArrays: AbstractFill

julia> # Work around a Zygote performance bug by using Zygote.literal_getfield manually.
       my_sum(x::AbstractFill) = Zygote.literal_getfield(x, Val(:value)) * length(x)
my_sum (generic function with 1 method)

julia> Zygote.gradient(my_sum, Fill(5.0, 10))
((value = 10.0, axes = nothing),)

julia> Zygote.gradient(x -> my_sum(x .+ 10), Fill(5.0, 10))
ERROR: MethodError: no method matching size(::NamedTuple{(:value, :axes), Tuple{Float64, Nothing}})
mcabbott commented 3 years ago

we insist that natural never escape pullbacks

I'm not sure what this means exactly. By "natural tangent" we mean something returned by a pullback which isa AbstractArray, no? (Or isa Number) What happens inside the function we don't care.

What I'm hoping to elucidate with examples is whether any types actually have to straddle the border between array-like and struct-like. I don't see why we shouldn't treat Fill as array-like, 100% of the time. Like Complex should be Number-like, 100% of the time. Conversion is free. Whereas it sounds like ColVecs / RowVecs are like NamedTuples, structural 100% of the time --- maybe they are what you mean by "never escapes"? Can we partition all types into one category or the other?

willtebbutt commented 3 years ago

I'm not sure what this means exactly. By "natural tangent" we mean something returned by a pullback which isa AbstractArray, no? (Or isa Number)

Yeah, more or less. My working definition is (excluding primals like Array and Float64 etc) that a natural tangent (/ differential, I've stopped calling them differentials) is anything that's not a Tangent, or one of the other tangent types in CRC (ZeroTangent, NoTangent, etc).

What happens inside the function we don't care.

Definitely -- modulo it returning acceptable types.

What I'm hoping to elucidate with examples is whether any types actually have to straddle the border between array-like and struct-like.

Symmetric / Diagonal have to straddle, because they can wrap arbitrary AbstractMatrixs / AbstractVectors respectively, and said AbstractMatrix / AbstractVector might be a container-like AbstractArray, meaning that it only has a structural tangent, which can't be put inside a Symmetric / Diagonal.

e.g. the tangent of a

Symmetric{<:Real, <:WoodburyPDMat}

would have to be a Tangent, because the tangent of a WoodburyPDMat must be a Tangent.

I don't see why we shouldn't treat Fill as array-like, 100% of the time.

This is a fair point -- I also can't see any reason, beyond my lingering doubts about the properties of using it as its own tangent (I'm concerned that sometimes it needs to behave like a struct and other times like an array, and maybe that leads to inconsistencies. I'm planning on trying to convince myself one way or the other over the next couple of days).

maybe they are what you mean by "never escapes"?

By "never escapes", I mean that a pullback always returns a Tangent (or ZeroTangent etc).

My understanding is that natural tangents are primarily a way to

  1. make generic rrules possible to write (which is great, this is a massive win), and
  2. produce an interpretable tangent.

This suggests to me that they're important in exactly two places:

  1. inside rrules, and
  2. outside of AD (e.g. to make the result of AD more interpretable if a structural tangent is returned and for some reason it's helpful to have in natural form. I'm sure examples exist here, but I've actually yet to see one)

Outside of these two situations, it's potentially a different matter. For example, inside the AD pipeline, we could require all tangents to be structural (or ZeroTangent etc), which would then be simple to deal with (+ always well defined, default constructor rrules happy). We could make this easy for users to do by requiring project_to to instead return the structural tangent (presumably a breaking change, not possible in 1.x, I assume?), and providing another function to do the structural -> natural mapping (Tangent to Symmetric, for example, when it makes sense to do, and erroring if the particular Tangent doesn't have a corresponding natural tangent). This could be done at the start of a pullback definition, if the user would rather have e.g. an AbstractArray than a Tangent (could also unthunk automatically etc).
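The sketch below shows what the proposed structural -> natural mapping, and its inverse, could look like for `Symmetric(A, :U)` only. A NamedTuple stands in for Tangent, and the function names are hypothetical, not ChainRulesCore API. The requirement used here is that both representations pair identically with every perturbation:

- the natural `dS` is paired against the perturbed primal `Symmetric(Ȧ, :U)`;
- the structural tangent is paired against the perturbation `Ȧ` of the `data` field.

```julia
using LinearAlgebra

# Natural (symmetric matrix) -> structural tangent for Symmetric(A, :U).
# Entries below the diagonal of `data` never reach the primal, so their tangent is zero;
# each strictly-upper entry of `data` appears twice in S, so it collects both mirrored entries.
function symmetric_to_structural(dS::AbstractMatrix)
    D = Matrix(dS)
    return (data = triu(D) + transpose(tril(D, -1)), uplo = nothing)
end

# Structural -> natural, erroring when no natural tangent exists, i.e. when the
# structural tangent puts weight on the ignored lower triangle of `data`.
function symmetric_to_natural(t::NamedTuple)
    iszero(tril(t.data, -1)) ||
        throw(ArgumentError("no natural tangent: nonzero lower triangle in structural tangent"))
    U = triu(t.data)
    return Symmetric((U + transpose(U)) / 2)
end

M = [1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 10.0]
dS = Symmetric(M + M')                           # a natural tangent
Adot = [1.0 -1.0 2.0; 0.5 3.0 -2.0; 4.0 1.0 1.0] # a perturbation of the `data` field
t = symmetric_to_structural(dS)
dot(dS, Symmetric(Adot, :U)) == dot(t.data, Adot)   # true: both pairings agree
symmetric_to_natural(t) ≈ dS                        # true: round trip
```

Under this scheme a rule author who wants an array-like tangent would call something like `symmetric_to_natural` at the top of the pullback. The error then surfaces inside that rule rather than in a failing `+` somewhere in the middle of the AD pipeline.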

My suspicion is that this kind of change would be minimally invasive, and would have the benefit of localising errors to the inside of rrules (instead of not being able to +, conversion to / from natural tangent might fail. The former case tends to happen in the middle of an AD pipeline, while the latter should happen inside an rrule that a human wrote. Consequently, the latter is probably easier to understand and debug). It would also make the rule-writer's expectations about the types they're going to get really explicit. By asking for the natural tangent, you're saying "I want something array-like, on which I can perform the usual array operations" (in the context of tangents for AbstractArrays).

For example, the rule for *(::AbstractVecOrMat, ::AbstractVecOrMat) would become something like


function rrule(
    ::typeof(*),
    A::AbstractVecOrMat{<:CommutativeMulNumber},
    B::AbstractVecOrMat{<:CommutativeMulNumber},
)
    Y = A * B
    project_Y = ProjectToNatural(Y)
    project_A = ProjectTo(A)
    project_B = ProjectTo(B)
    function times_pullback(ȳ)
        Ȳ = project_Y(ȳ)
        dA = @thunk(project_A(Ȳ * B'))
        dB = @thunk(project_B(A' * Ȳ))
        return NoTangent(), dA, dB
    end
    return Y, times_pullback
end

edit: re FillArrays, if you have an operation which AD successfully derives, it'll presumably expect a Tangent as the tangent for its output on the reverse-pass. Trying to find a good example now.

edit2: no, with getfield implemented appropriately, this should be fine.

edit3: break up long sentence.

mcabbott commented 3 years ago

By "never escapes", I mean that a pullback always returns a Tangent

You mean every pullback always returns a Tangent, I think. (Or maybe not for Complex.) I'm going to dub this the nuclear option. I think it means that arrays unknown to the AD system can never propagate through AD, as arrays. They all need special reconstructor functions before they can be used within generic rules. It can't "just work" with SArrays & OffsetArrays.

The "not straddling" idea is I think to tie the behaviour to type of x. Perhaps, for some types, all pullbacks always return a Tangent, while for others, no pullback ever returns such. The type of x::Symmetric contains the type of parent(x), maybe that counts. If every x is on one side of the fence or the other, then my problem (2) never occurs, + is never confused.

I still think we need better examples of problem (1). That is, sufficiently weird array types not to have easy natural tangents, which are nevertheless the result of the forward pass of some generic function, whose generic rrule we might hope to use. If the construction of a natural tangent is possible but expensive, then you can see some virtue to delaying it until certainly necessary. If the construction is impossible, then you cannot use the generic rrule under any scheme, so we're not in problem (1) anymore. (You will either need to write another, or opt-out.)

willtebbutt commented 3 years ago

I'm going to dub this the nuclear option.

Haha fair enough. I would prefer it to be characterised as the "clean" option, in that it cleanly separates the world of the users / rule-implementers from the world of the AD.

I think it means that arrays unknown to the AD system can never propagate through AD, as arrays.

Is this a problem? AD systems like structural tangents -- that's their "native" representation of any struct, and by extension their native representation of any AbstractArray that isn't a primitive.

They all need special reconstructor functions before they can be used within generic rules.

I'm not convinced we don't already have to do this somewhere, it's just unclear where. For example, we already have to write code to map from the natural to the structural inside a hand-written rrule for the constructor.

How is it that an AbstractArray tangent comes into existence in the first place if not via conversion from a structural tangent? (structural tangents being the only thing that can be generated automatically).

The "not straddling" idea is I think to tie the behaviour to type of x. Perhaps, for some types, all pullbacks always return a Tangent, while for others, no pullback ever returns such. The type of x::Symmetric contains the type of parent(x), maybe that counts. If every x is on one side of the fence or the other, then my problem (2) never occurs, + is never confused.

Sorry, I'm really struggling to follow this para. Could you elaborate?

I still think we need better examples of problem (1).

Is the Fill example not sufficient? It has the key properties that we care about, does it not?

edit: ignore my last comment. Reading your para again.

mcabbott commented 3 years ago

How is it that an AbstractArray tangent comes into existence in the first place if not via conversion from a structural

The generic rule for sum(x) would be something like dx = similar(x).=1. Which will produce a CuArray, or a StaticArray, or an OffsetArray, etc. Many are like this.

we already have to write code to map from the natural to the structural inside a hand-written rrule for the constructor

Yes, usually, this is why my (2) is the easy problem. But the reconstructor is the map in the other direction.

Fill's "reconstructor" is trivial and free. The simplest Symmetric's is easy but not necessarily cheap. I am not sure whether this is possible at all for some of the weird wrapper combinations you mention. But I also don't know whether those can plausibly be produced by a generic rule as specified. If they cannot, then problem (1) does not occur.

Problem (2) is about two tangents for the same x; if dx-always-natural vs. dx-always-structural were determined by typeof(x) (and enforced somehow, e.g. by projectors), then it is almost a tautology that they will always agree. How easily this could be implemented would need more thought.

willtebbutt commented 3 years ago

Addressing your previous comment:

sufficiently weird array types not to have easy natural tangents

If we try to apply a generic rrule to a type which doesn't have a natural tangent defined, then I agree that you couldn't ever hope to use generic rrules, and that's fine as far as I'm concerned.

To my mind, the more troublesome case is that in which an AbstractArray could make use of a generic rrule, but in our current setup is unable to, because at no point in our rrules do we say "give me the natural tangent". As things stand, I think the problem will arise whenever we have the following structure:

foo(x::AbstractArray) = _hard_code_to_differentiate_so_has_generic_rrule
bar(x::Fill) = easy_to_differentiate_code_without_rrule

Forwards pass:

a = Fill(5.0, 2)
b = foo(a) # b isa Fill
c = bar(b)

Reverse pass:

Db = pullback_bar(Dc) # Db isa Tangent
Da = pullback_foo(Db) # hits a generic rrule and breaks

Of course, you could replace Fill with anything that's not a primitive.

A concrete example from the stdlib:

using LinearAlgebra
using Zygote

foo(x) = 5x
bar = parent

b, pullback_foo = Zygote.pullback(foo, Symmetric(randn(5, 5)))
c, pullback_bar = Zygote.pullback(bar, b)

Dc = randn(5, 5)
Db = pullback_bar(Dc)[1]
Db isa NamedTuple # structural tangent
pullback_foo(Db) # errors when it hits https://github.com/JuliaDiff/ChainRules.jl/blob/555ac11cf25c03e9dbb8dcd9d431bc8ff11349ee/src/rulesets/Base/arraymath.jl#L93

This uses Zygote types, but the point is clear. You could replace foo with any of the matrix functions whose rrules are defined here, ^, or inv, and get the same kind of problem.

This example also works with parent and Diagonal, when foo is:

I tried the same thing for parent and Adjoint, and got an Adjoint back for Db. This means that someone has written a rule for this method. It turns out that it's in Zygote, so probably Mike or I wrote it at some point long ago. This rule will break if anyone ever wants the Adjoint of a container-like AbstractArray that only has a structural tangent. Even if this weren't the case, we just shouldn't need to write a rule for such a simple function.

Addressing your most recent comment:

The generic rule for sum(x) would be something like dx = similar(x).=1.

I see -- this seems like a reasonable heuristic. I guess we have to be careful about assuming mutability, though, e.g. in the StaticArrays case.

But the reconstructor is the map in the other direction.

Indeed. I was just pointing out that the map is already written in that direction. My guess would be that, given that code, writing it in the other direction will be very straightforward. It also appears necessary to make scenario (1) (e.g. the Symmetric example above) work in general.

The simplest Symmetric's is easy but not necessarily cheap

Agreed.

am not sure whether this is possible at all for some of the weird wrapper combinations you mention. But I also don't know whether those can plausibly be produced by a generic rule as specified. If they cannot, then problem (1) does not occur.

Yeah -- it's not going to be possible for AbstractArrays without a natural-structural mapping to participate in the generic rrules, and that seems acceptable to me.

It's definitely plausible that such an array could be output from an operation with a generic rrule. It would happen if the array's author implemented a generic operation (e.g. *(::Number, ::AbstractArray)) for their type in a non-differentiable manner, but one which outputs their type. (Had they implemented it in a differentiable manner, they should @opt_out.)

I agree it seems hard to see how it could happen if the array author relies on fallback code somewhere.

willtebbutt commented 3 years ago

I'm thinking again about my proposal from a few comments ago, and how the minimal number of translations between natural and structural tangents could be achieved.

Consider again *:

function rrule(
    ::typeof(*),
    A::AbstractVecOrMat{<:CommutativeMulNumber},
    B::AbstractVecOrMat{<:CommutativeMulNumber},
)
    Y = A * B
    project_Y = ProjectToNatural(Y)
    project_A = ProjectTo(A)
    project_B = ProjectTo(B)
    function times_pullback(ȳ)
        Ȳ = project_Y(ȳ)
        dA = @thunk(project_A(Ȳ * B'))
        dB = @thunk(project_B(A' * Ȳ))
        return NoTangent(), dA, dB
    end
    return Y, times_pullback
end

Let's assume that we allow naturals and structurals to float around anywhere they like, so not the non-escaping situation I discussed.

Further, let's keep ProjectTo's semantics the same (yay, no breaking change).

Further assume that we have access to a function get_structural(tangent) which is the identity if the argument is a Tangent, and produces the structural tangent if it's a natural tangent (I'm not sure such a function could be implemented precisely like this, but bear with me for now).

You could then easily define + between different tangent representations:

+(t::Tangent, s::Any) = t + get_structural(s)
+(t::Any, s::Tangent) = get_structural(t) + s

Natural tangents have to have + defined, so the natural + natural case isn't a problem.
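A toy version of this scheme, using Diagonal because its structural tangent (a NamedTuple with a diag field, as Zygote produces) maps to the natural one without needing the primal. get_structural follows the name used above; add_tangent and add_structural are hypothetical stand-ins for +, used so that the sketch does not overload + on NamedTuple, which would be type piracy.

```julia
using LinearAlgebra

# Identity on structural tangents (NamedTuples here, standing in for Tangent).
get_structural(t::NamedTuple) = t
# Natural -> structural for Diagonal: only the diag field exists structurally.
get_structural(d::Diagonal) = (diag = d.diag,)

# Field-wise addition of structural tangents with identical field names.
add_structural(a::NamedTuple{K}, b::NamedTuple{K}) where {K} =
    NamedTuple{K}(map(+, values(a), values(b)))

# Hypothetical stand-in for + between tangent representations.
add_tangent(t::NamedTuple, s) = add_structural(t, get_structural(s))
add_tangent(t, s::NamedTuple) = add_structural(get_structural(t), s)
add_tangent(t::NamedTuple, s::NamedTuple) = add_structural(t, s)
add_tangent(t, s) = t + s  # natural + natural already works
```

This only works so simply because the Diagonal mapping is primal-independent; for Fill or Symmetric, get_structural would need information about the primal, which is exactly the concern about implementing it without access to the primal.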

In the above scheme, conversions between natural and structural would only happen if an rrule author explicitly requests the structural, or if a structural and natural interact, which seems like it ought to achieve the minimum possible number of conversions.

I think you'd just need to put a call to get_structural inside Zygote's / Diffractor's constructor rules (since constructors don't know how to handle naturals), and you'd be grand (everything would work).

edit: my only concern with this scheme is the plausibility of implementing get_structural in the absence of access to the primal. I'm pretty sure that you'd need to know which primal you're dealing with in order to implement it, for the same reasons that you do in order to implement ProjectTo.

willtebbutt commented 3 years ago

One other point: the fact that situation (1) occurs means that we need to write structural -> natural mappings anyway. I don't think there's any escaping it. At the moment we just have code that we can't AD 🤷

mcabbott commented 3 years ago

This is a lot of text. I think that all examples with Fill are solved by

(project::ProjectTo{Fill})(dx::Tangent) = Fill(dx.value / prod(length, project.axes), project.axes)

matching what already happens for Complex

julia> ProjectTo(1+im)(Tangent{Complex}(; re=33))
33.0 + 0.0im

plus having the gradient of getproperty apply projectors.
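A self-contained sketch of why this works, using a hypothetical ToyFill in place of FillArrays' Fill, a NamedTuple in place of Tangent, and a hypothetical project_fill in place of the ProjectTo method. A generic pullback (here for 5x, written for any AbstractArray cotangent) cannot consume the structural tangent; dividing the value gradient by the length gives a natural tangent whose pairing with a perturbation of value matches the structural one.

```julia
# Hypothetical stand-in for FillArrays.Fill: a lazy array of one repeated value.
struct ToyFill{T} <: AbstractVector{T}
    value::T
    len::Int
end
Base.size(f::ToyFill) = (f.len,)
Base.getindex(f::ToyFill, i::Int) = (checkbounds(f, i); f.value)

# Generic pullback for y = 5x, written for natural (AbstractArray) cotangents.
times5_pullback(ȳ::AbstractArray) = 5 .* ȳ

# Structural tangent for a ToyFill: the gradient w.r.t. its value field.
dstruct = (value = 6.0,)

# Reconstructor: spread the value gradient evenly over all entries.
project_fill(f::ToyFill, dx::NamedTuple) = ToyFill(dx.value / length(f), length(f))
```

Passing dstruct straight to times5_pullback throws a MethodError; passing project_fill(x, dstruct) works, and its result pairs with a uniform perturbation exactly as dstruct.value does.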

The example for (1) so far is

gradient(abs ∘ first ∘ parent ∘ log, Symmetric(rand(3,3)))  # error, adjoint(::Tangent)

in which log (from your list, less trivial than 5*) is the function with a generic rule, which propagates Symmetric.

One objection here is that parent is just another way of writing getproperty. It's only safe to call if you know precisely the type already -- a weird mix of generic and specific. (And first keeps an element we know not to be junk, sum would not. Which means parent is doing nothing.)

For this, it sounds like the projector for Symmetric should reconstruct from tangents. Maybe that's more expensive than ideal, but cheap compared to log.

Zygote already does this for Adjoint, where it's free. But:

will break if anyone ever wants the Adjoint of a container-like AbstractArray that only has a structural tangent

Maybe. The real problem (1) is when this combination is also produced by the forward pass of a generic rule. That's what I think we still lack examples of.

mcabbott commented 3 years ago

For problem (2), I'm not sure that calling to_structural on the mismatched natural is always going to work. We allow (some) subspaces which aren't subtypes. If you want to add Tangent{UpperTriangular}(...) + Diagonal(...), turning the 2nd into Tangent{Diagonal}(...) won't help. Although I don't have a good example that would produce this.

willtebbutt commented 3 years ago

Maybe. The real problem (1) is when this combination is also produced by the forward pass of a generic rule. That's what I think we still lack examples of.

I stand by the above examples, and believe they are sufficient. To my mind it is enough that someone could do this, and there's no reasonable grounds on which we could object to someone doing this, in order for us to need to cover it. I'm happy to produce a MWE, but I don't believe there's a need to go hunting in the wild.

I agree that a line has to be drawn somewhere (e.g. we can agree that strings are non-differentiable, even though someone could in principle write a differentiable programme in which data passes through them), but I can find no such objection in terms of the existing examples.

One objection here is that parent is just another way of writing getproperty. It's only safe to call if you know precisely the type already -- a weird mix of generic and specific.

This example is analogous to the Fill example I gave before. Sometimes you want to specialise, other times you need to be generic, and they could be in the same programme.

This is beside the point, though. The example in question is perfectly legal differentiable Julia code, and it doesn't involve mutation (because we have a rule for log). To my mind that is sufficient, but it's also part of the public API of a standard library. We have to cover this kind of thing by default, otherwise we get in the way of AD doing AD.

This is a lot of text. I think that all examples with Fill are solved by

I think this is correct, in this particular case, because Fill could plausibly be a primitive (I think. Again, given our lack of definitions and agreed-upon consistency criteria, I'm not going to say that I fully understand this case, and therefore I'm not 100% certain that this is a viable solution).

I wonder whether this is desirable though (making things primitives in general).

Suppose that the set of things that we ultimately decide that we need to do in order to support structural-to-natural mapping / projection solve all of the problems we're encountering, and work for all AbstractArrays. In that case, we could just do those things for Fills, SArrays etc, and not bother making them primitives. As far as I can see, we wouldn't lose anything by doing so, but we would gain simplicity -- we could provide a single clear set of instructions to array authors that want their types to play nicely with generic rrules, rather than giving them multiple options. This would also make our lives easier (fewer things to maintain / test / understand / define).

Phrased differently, making things primitives doesn't solve the wider problem. It only works in a subset of the cases that we care about, and we need to solve the more general problem anyway. If the more general solution isn't any more complicated than the specific solution, we only need the more general one.

(I realise this is back-tracking a bit on my earlier position regarding making SArray a primitive, but I can't think of any compelling advantages to doing so if we solve the problem more generally).

willtebbutt commented 3 years ago

If you want to add Tangent{UpperTriangular}(...) + Diagonal(...), turning the 2nd into Tangent{Diagonal}(...) won't help.

That's an interesting example. I've only implemented a single to_structural method for each primal type, but I definitely believe that you could have multiple valid naturals. So multiple methods of to_structural are probably required for each primal. For example

to_structural(b::Bijection{<:UpperTriangular}, n::Diagonal) = ...

seems perfectly reasonable to me.
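One way to read this: dispatch on the primal (passed directly here, rather than via the hypothetical Bijection wrapper above) and accept any natural tangent that is an AbstractMatrix, so a Diagonal and a dense Matrix share one method. The name to_structural and the NamedTuple with a data field, standing in for Tangent{UpperTriangular}, are assumptions for illustration. For U = UpperTriangular(data), only the upper triangle of data reaches U, so the structural gradient is the upper triangle of the natural one.

```julia
using LinearAlgebra

# Natural -> structural for an UpperTriangular primal. Entries of the natural
# tangent below the diagonal are dropped, since data[i, j] with i > j does not
# influence the primal.
to_structural(::UpperTriangular, n::AbstractMatrix) = (data = triu(Matrix(n)),)
```

The target structure is determined by the primal, not by the type of the natural tangent, which is why turning a Diagonal into a Tangent{Diagonal} would not help when the primal is an UpperTriangular.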

That being said, say you got a Diagonal natural for an UpperTriangular as you suggest. What should ProjectTo do to the Diagonal? Should it produce an UpperTriangular, or leave it as a Diagonal?

mcabbott commented 3 years ago

What should ProjectTo do

Right now it explicitly allows this, but I'm not sure much is gained: https://github.com/JuliaDiff/ChainRulesCore.jl/blob/master/src/projection.jl#L330-L337 Without that, you'd get UpperTriangular(Diagonal(...)).

A sketch of a larger machine for such things: #390

Proposal to return Diagonal(Fill(...)) from tr: https://github.com/JuliaDiff/ChainRules.jl/issues/46#issuecomment-620707731 . The projector for a full matrix will certainly allow a Diagonal gradient through.

mcabbott commented 3 years ago

are sufficient. To my mind it is enough that someone could do this, [...] don't believe there's a need to go hunting in the wild.

Re (1), I'm dubious that we can design a good general system without quite a few examples which actually need it. Maybe they have other weird properties which neither of us have thought of, nature is more creative than us. Heard about the zoo-keeper who scaled up his elephant-cage design once he heard there are 50-ton mammals out there?

willtebbutt commented 3 years ago

Re (1), I'm dubious that we can design a good general system without quite a few examples which actually need it.

I disagree with this for the following reason: our starting point is that we can do reverse-mode AD in any mutation-free Julia programme built from our primitives / things built from them (some Reals, Arrays, any struct built from these, and some basic operations on these (getfield, +, *, getindex etc)) via composition.

Consequently, whenever AD breaks we must have one of the following problems:

  1. we're missing, or have incorrectly defined, a primitive (either a non-struct type, or we hit C, or some super low-level bit-shift operations that define how + and * work that we somehow miss in our existing collection of primitives), or
  2. we try to AD mutating code, or
  3. a mutation-hiding rule produces something which is inconsistent with what AD would produce, were the function to which the rule applies written in a non-mutating manner (which is always possible), or
  4. a mutation-hiding rule is unable to consume something that AD has automatically produced.

I'm confident that all of the problems we're discussing here are of the 3 or 4 variety. We're not discussing problems 1 (because all of the AbstractArrays we're talking about are structs built from things we understand) or 2. I accept it's possible that I'm missing a problem above, but I'm confident that if I am, it's not something that's affecting us here.

3 and 4 are things we've been struggling with since Zygote was first written, because we never properly sat down and figured out how to map between the data structures that we know how to do AD on (structural tangents) and representations of those data structures that are convenient when writing generic rules (natural tangents). A huge amount of progress has been made on problem 3 in the last few months by yourself, Lyndon, and Miha, in the form of the projection work. However, we still don't properly understand the relationship between structural and natural tangents, which means that we don't have a generic recipe to prevent problems 3 and 4 from arising. They will continue to plague us until we develop this understanding, and do something with it.

If we develop a proper understanding of the relationship between structural and natural tangents, we should be able to

  1. resolve problems 3 and 4 for all existing types for which we write rules, and
  2. write down a clear set of instructions and implement test utils for type-authors who wish to utilise natural tangents to take advantage of generically-typed rrules on their types.

It must be possible to do this because we're not trying to write rules for programmes that we don't already know how to differentiate in principle -- we just have technical issues around mutation.

edit: I phrased the first sentence poorly. I don't disagree that we need examples, I just think we've already got enough to try and understand the nature of the problem at hand.

edit2: I explicitly suggested that it matters whether an rrule is implemented for the sake of performance or to hide mutations in 3 and 4. This was incorrect -- 3 and 4 persist regardless of the reason for implementing the rule.

willtebbutt commented 3 years ago

A good example of something that came up today with some people at Invenia (@mzgubic, @thomasgudjonwright):

sqrt(Diagonal(x))

for some positive-valued vector x. If you take a look at the implementation of this operation in the standard library, you'll see that it's really easy to AD. It's a nice non-mutating implementation, just requiring getfield, Diagonal, and broadcast to AD properly. Indeed, my understanding is that it did this, and produced a Tangent (structural tangent), as expected.

My understanding is that Tom was writing some code that called the above, but also some other operations on Diagonal, whose pullbacks happened to return a Diagonal (natural tangent). These interacted later on in the code, and produced an error. This cost several hours of people's time, spent trying to figure out what was going on and how to fix it. The result is that Miha will be opening a ChainRules PR in the near future.

Tom or Miha: please correct me if I've misrepresented any of the above.

For me, this is a clear example of the kind of problem we shouldn't be having. AD did its job perfectly in producing the result of sqrt(Diagonal(x)), but the structural tangent that was produced interacted poorly with other tangent representations, causing code to break.

Since the relationship between the structural and natural tangent of a Diagonal appears to be trivial, this is a really straightforward problem to patch (I believe @mzgubic is planning to make the operation return a natural tangent instead). But it's representative of a larger problem around the relationship between structural and natural tangents, and our (present) inability to gracefully handle interactions between the two, as we've discussed extensively in this issue.
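A sketch of what a natural-tangent pullback for sqrt(::Diagonal) might look like, written as a plain Julia function rather than an actual rrule; sqrt_diagonal_pullback is hypothetical and assumes positive entries. Taking diag of the cotangent means it accepts either a Diagonal or a dense Matrix, since the output's off-diagonal entries are always zero, and it always returns a Diagonal, so the result adds cleanly to other Diagonal tangents.

```julia
using LinearAlgebra

# y = sqrt(Diagonal(x)) = Diagonal(sqrt.(x)), so dy_ii/dx_ii = 1 / (2 sqrt(x_i)).
function sqrt_diagonal_pullback(X::Diagonal, ȳ::AbstractMatrix)
    # Off-diagonal entries of y are always zero, so only diag(ȳ) matters.
    return Diagonal(diag(ȳ) ./ (2 .* sqrt.(X.diag)))
end
```

A central finite difference on sum(sqrt(Diagonal(x))) agrees with the first diagonal entry of the pullback.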

thomasgudjonwright commented 3 years ago

Tom or Miha: please correct me if I've misrepresented any of the above.

Thanks for bringing this up, seems just about right!

mzgubic commented 3 years ago

One thing to add would be that if Zygote used ChainRules types internally we would not have this issue, since Tangent{Diagonal} and Diagonal can be added to each other. The problem comes when trying to add a NamedTuple to a Diagonal.

mcabbott commented 3 years ago

TIL. Could or should more of these be made to work?

julia> d = Diagonal([1,2,3]);

julia> t = Tangent{typeof(d)}(diag=[4,5,6]);

julia> d + t
3×3 Diagonal{Int64, Vector{Int64}}:
 5  ⋅  ⋅
 ⋅  7  ⋅
 ⋅  ⋅  9

julia> t2 = Tangent{Any}(diag=[4,5,6]);

julia> d + t2  # could guess from d?
ERROR: ArgumentError: type does not have a definite number of fields

julia> u = UpperTriangular([1 2 3; 4 5 6; 7 8 9]);

julia> u + t  # could re-construct Diagonal anyway and then try + on naturals?
ERROR: MethodError: no method matching +(::UpperTriangular{Int64, Matrix{Int64}}, ::Tangent{Diagonal ...

mcabbott commented 3 years ago

To explore these things, I made some examples here: https://github.com/mcabbott/OddArrays.jl

In particular, this has examples of the sort we didn't have above. For r = Rotation(θ) there is a method for r * r which returns the same type. This r doesn't really have a nice "natural" gradient representation. But the generic rule for * expects one. To make it work I opted out of the rule by hand, but I wonder if that can be automated -- probably no generic rule should accept a Tangent.

one option is overloading +(::Tangent{FooArray}, ::AbstractArray)

I think this happens too late. The map from the array to the tangent in general depends on the original array. The easy cases where it doesn't, like Diagonal, seem to be the cases where we can happily standardise on the "natural" form.
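A sketch of why the map depends on the original array, using a hypothetical 2×2 Rotation modelled loosely on the OddArrays.jl example (not copied from it). Its structural tangent is a single number θ. Mapping a natural (dense matrix) cotangent N onto it means pairing N with ∂R/∂θ, which depends on the primal's θ, so a function that only sees N cannot do it.

```julia
# Hypothetical rotation matrix parameterised by a single angle.
struct Rotation <: AbstractMatrix{Float64}
    θ::Float64
end
rotmat(θ) = [cos(θ) -sin(θ); sin(θ) cos(θ)]
drotmat(θ) = [-sin(θ) -cos(θ); cos(θ) -sin(θ)]  # ∂R/∂θ
Base.size(::Rotation) = (2, 2)
Base.getindex(r::Rotation, i::Int, j::Int) = rotmat(r.θ)[i, j]

# Same-type product: the angles add.
Base.:*(a::Rotation, b::Rotation) = Rotation(a.θ + b.θ)

# Natural -> structural: needs the primal r, because ∂R/∂θ depends on r.θ.
to_structural(r::Rotation, N::AbstractMatrix) = (θ = sum(N .* drotmat(r.θ)),)
```

Going the other way, from a θ gradient to a dense natural tangent, is not unique: any matrix with the same pairing against ∂R/∂θ would do, so picking one is a convention rather than something the structure dictates.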