JuliaDiff / ChainRulesCore.jl

AD-backend agnostic system defining custom forward and reverse mode rules. This is the lightweight core to allow you to define rules for your functions in your packages, without depending on any particular AD system.

The Extensibility Problem, for propagator closures #53

Closed willtebbutt closed 3 years ago

willtebbutt commented 5 years ago

Here's a thing that isn't currently possible and is, I believe, something that we might actually want to care about. Consider the pullback for AbstractMatrix multiplication:

function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
    function times_pullback(Ȳ)
        return (NO_FIELDS, @thunk(Ȳ * B'), @thunk(A' * Ȳ))
    end
    return A * B, times_pullback
end

Provided that Ȳ is itself an AbstractMatrix, for which * with other matrices will be correctly defined assuming getindex is correctly defined, something correct will happen even if it's slow.

Now consider the case that Ȳ is a NamedTuple, possibly because Y is some non-Matrix AbstractMatrix. Now what happens? The above breaks: * isn't defined for NamedTuples, nor is it possible to extend times_pullback to handle such a Ȳ from outside the original rrule definition. One's only recourse is to add a completely new definition of pullback for * with the AbstractTypeofA and AbstractTypeofB that one expects to see such a Ȳ with. This itself presupposes that a method of pullback(typeof(*), ::AbstractTypeofA, ::AbstractTypeofB) doesn't already exist; if it does, no option is available but to modify the existing method.
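A minimal sketch of the failure described above. It uses a toy `rrule_sketch` (a hypothetical name, not the ChainRulesCore API) with plain tuples instead of thunks, and assumes a structural cotangent is represented as a NamedTuple with a `diag` field:

```julia
# Toy closure-based rule in the style of the rrule above (no ChainRulesCore dependency).
function rrule_sketch(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
    times_pullback(Ȳ) = (Ȳ * B', A' * Ȳ)
    return A * B, times_pullback
end

A = [1.0 2.0; 3.0 4.0]
B = [0.0 1.0; 1.0 0.0]
Y, back = rrule_sketch(*, A, B)

# A matrix-valued cotangent works.
Ā, B̄ = back(ones(2, 2))

# A structural (NamedTuple) cotangent does not: `*` has no method for it,
# and `times_pullback` cannot be given new methods from outside the rule.
Ȳ_nt = (diag = [1.0, 1.0],)
err = try
    back(Ȳ_nt)
    nothing
catch e
    e
end
```

Here `err` ends up holding the exception thrown by the closure, since the only code path inside `times_pullback` is the matrix one.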

Phrased differently, the current design requires that each rrule be implemented to handle every possible type of Ȳ that it might ever see. This is clearly an unreasonable requirement because Julia permits the creation of new types, and the input types to pullback do not uniquely specify the type of Ȳ. Assuming this unreasonable requirement is unmet, we are left with two options when the above is encountered:

  1. Modify the original rrule to handle the new type encountered (the bad-ness of this is assumed self-evident)
  2. Write a new rrule specialised to different types of A and B. This really isn't great from a code re-use perspective. * is not a pathological case because the forwards-pass is quite straightforward, but other cases are worse.

This problem appears to manifest itself in cases where the forwards-pass is perfectly good for multiple types, but the reverse-pass requires care. For example Diagonal * Matrix: the forwards-pass and the data required on the reverse-pass are no different than for Matrix * Matrix, but the reverse-pass implementation is necessarily quite different.
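To make the Diagonal * Matrix case concrete, here is a small sketch. The function names are hypothetical, and it assumes real entries and that the cotangent for a Diagonal argument should itself be a Diagonal:

```julia
using LinearAlgebra

# Dense case: both cotangents are ordinary matrix products.
dense_pullback(A, B, Ȳ) = (Ȳ * B', A' * Ȳ)

# Diagonal case: the same stored data (A, B) is needed, but Ā only needs the
# diagonal of Ȳ * B', which can be computed without forming the full product.
function diagonal_pullback(A::Diagonal, B::AbstractMatrix, Ȳ::AbstractMatrix)
    Ā = Diagonal(vec(sum(Ȳ .* B; dims=2)))  # entry i is sum_j Ȳ[i, j] * B[i, j]
    B̄ = A' * Ȳ
    return Ā, B̄
end
```

Both functions consume identical forward-pass data, which is why a second full rrule for the Diagonal case duplicates the forwards-pass purely to change the reverse-pass.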

This lack of extensibility is a direct consequence of the (value, back) = pullback(...) design choice that ChainRules / Zygote make. Nabla made a slightly different design choice in which the forwards- and reverse- bits of a pullback were separate functions, so you could extend things. This design doesn't share the extensibility issue that the ChainRules / Zygote style presents, but equally doesn't immediately enable the same sharing of state on the forwards- and reverse-passes.

One possible resolution would be to adopt the separated forwards- and reverse- passes chosen in Nabla, and allow an arbitrary communication object to be shared between the forwards- and reverse- passes.

function forward(::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
    return A * B, (signature=(typeof(*), typeof(A), typeof(B)), A=A, B=B)
end

function pullback(::NamedTuple{(:signature, :A, :B), (Tuple{typeof(*), AbstractMatrix, AbstractMatrix}, AbstractMatrix, AbstractMatrix)}, Ȳ::AbstractMatrix)
    # do stuff
end
function pullback(<NamedTuple stuff>, Ȳ::NamedTuple{<some field names>})
    # do other stuff
end

A forward call would therefore not return a closure, but rather whatever intermediate data is deemed by the implementer to be important for the reverse-pass. pullback is then called with the appropriate signature to evaluate the adjoint. This interface is somewhat more verbose than the closure-based interface due to the need to copy-paste the signature all over the place, although this may be alleviated with some careful metaprogramming tooling. I would anticipate that we would also see an improvement in stack-trace readability, since we would get direct calls to a pullback function, with the types of forwards arguments placed prominently.
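A runnable sketch of this separated design, under assumptions that are not part of the proposal: a small struct (`MatMulRecord`, hypothetical) replaces the NamedTuple-with-signature so that dispatch on the argument types comes for free from its type parameters, and the functions are named `forward_sketch` and `pullback_sketch` to avoid suggesting a real API:

```julia
using LinearAlgebra

# Hypothetical record of what the reverse pass needs. Its type parameters carry
# the argument types, so `pullback_sketch` can dispatch on them.
struct MatMulRecord{TA<:AbstractMatrix,TB<:AbstractMatrix}
    A::TA
    B::TB
end

# Forward pass: return the primal result and plain data, not a closure.
forward_sketch(::typeof(*), A::AbstractMatrix, B::AbstractMatrix) = (A * B, MatMulRecord(A, B))

# Generic reverse pass for matrix-valued cotangents.
pullback_sketch(r::MatMulRecord, Ȳ::AbstractMatrix) = (Ȳ * r.B', r.A' * Ȳ)

# A method added later, possibly from another package: a structural cotangent
# for a Diagonal output is converted and handed to the generic method.
pullback_sketch(r::MatMulRecord, Ȳ::NamedTuple{(:diag,)}) = pullback_sketch(r, Diagonal(Ȳ.diag))
```

The last method is the point of the exercise: it extends the reverse-pass for a new cotangent type without touching the forward definition.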

In summary, this change buys us the ability to extend reverse-pass behaviour using multiple dispatch, and hence code-reuse, at the expense of increased verbosity and the need to explicitly specify the data from the forwards-pass that may be required on the reverse.

This isn't something that needs resolving immediately, but I feel it should be given some consideration so that we can at least be aware that this is an issue we're choosing to ignore if nothing is done about it.

Side note: this appears analogous to the expression problem, on which you can consult Stefan's JuliaCon talk. Specifically, that we can't define new methods of back for existing pullbacks is (I think) analogous to not being able to define new methods that extend the functionality of existing types.

oxinabox commented 5 years ago

I think this is something to keep in mind. Having rewritten the core of ChainRules once this month, I'd rather not do it again for a few minor versions. Also, I will want to put out the metaprogramming helper macros (#44) before we tackle this anyway, since solving it will likely need to be done by using them.

This also relates to the fact you can't actually pass AbstractDifferentials as input to most pullbacks, they need to be externed normally anyway.

Interesting idea: if we did go to pullback as a global function, taking a signature, and (Ȳ) and some extra information, we might be able to encode that extra information as a closure that has the default implementation.

Ideally, we would dispatch the rrule itself based on what would be passed to its pullback, since we might also want to do the forward-pass differently and capture different information to give to that pullback.

But that requires global information, some of which cannot be obtained without running into the halting problem. So we can't do that, but we should think about how useful the ability to add new pullbacks without adding new rrules is.

The input to the pullback (Ȳ) has to be a very similar type (need to be able to subtract them, I think?) to the output of the forward pass (Y). And we haven't really implemented #8 yet. So I don't entirely know what that looks like, and whether, for example, we will end up with NamedTuples that want to be used as Matrices.

Definitely this issue is one to think on.

oxinabox commented 5 years ago

How about this. Rather than using our pullback directly as the canonical way to propagate gradients, we have a function backpropagate that takes the sig, the pullback, and the Ȳ. By default it just calls the pullback, but it can be overloaded.

It can solve a few problems:

  1. Externing before calling pullbacks.
  2. Not wasting time calling the pullback if Ȳ is Zero.
  3. Overloading backpropagate for different Ȳ types. And just using the pullback closure as a collection of fields.

We also need a function forward to help track signature types.

function forward(f::F, args...)
    ret, default_pullback = rrule(f, args...)
    sig = Tuple{F, typeof.(args)...}
    return ret, default_pullback, sig
end 

So this is what it might look like:

backpropagate(sig, default_pullback, Ȳ) = default_pullback(extern(Ȳ))
backpropagate(sig, pullback_info, Ȳ::Zero) = Zero()

function backpropagate(sig::Type{<:Tuple{typeof(*), Special1, Special2}}, pullback_info, Ȳ)
    # Even though pullback_info is a closure, we never call it,
    # it might as well be a NamedTuple.

    Ā = @thunk(g(Ȳ, pullback_info.A))
    B̄ = h(Ȳ, pullback_info.B)
    return NO_FIELDS, Ā, B̄
end

Bonus fact, which may or may not apply to storing sig as part of a NamedTuple: storing it as a Tuple type means getting covariant type parameters.

It might kinda be part of replacing accumulate, as I am not sure how that works in the new world. (It isn't broken, I am just unsure how useful it is.)
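A self-contained sketch of the backpropagate idea above. All names ending in `_sketch` and the `ZeroSketch` type are hypothetical stand-ins (the latter for the Zero differential), externing is omitted, and the Diagonal overload is an assumed example of a special case. It relies on two Julia behaviours: Tuple types are covariant in their parameters, and variables captured by a closure are accessible as its fields.

```julia
using LinearAlgebra

struct ZeroSketch end  # stand-in for the Zero differential

function rrule_sketch(::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
    times_pullback(Ȳ) = (Ȳ * B', A' * Ȳ)  # captures A and B as fields
    return A * B, times_pullback
end

# Wrap the rule so that the call signature travels with the pullback.
function forward_sketch(f::F, args...) where {F}
    ret, pb = rrule_sketch(f, args...)
    return ret, pb, Tuple{F, map(typeof, args)...}
end

# Default: just call the closure. Zero cotangents skip the work entirely.
backpropagate_sketch(sig, pb, Ȳ) = pb(Ȳ)
backpropagate_sketch(sig, pb, ::ZeroSketch) = ZeroSketch()

# Overload for Diagonal * Matrix: the closure is never called, only read as data.
function backpropagate_sketch(::Type{<:Tuple{typeof(*), Diagonal, AbstractMatrix}}, pb, Ȳ::AbstractMatrix)
    Ā = Diagonal(vec(sum(Ȳ .* pb.B; dims=2)))
    B̄ = pb.A' * Ȳ
    return Ā, B̄
end
```

Reading `pb.A` and `pb.B` depends on the closure capturing those names without reassigning them, which is an implementation detail a real design would probably want to make explicit.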

MikeInnes commented 5 years ago

I'd be interested in more specifics of the use case for this and the kind of extensibility you need. My main issue with a separate pullback function is that it simulates closures (i.e. a bundle of data + code) anyway; you're going to end up with something equivalent but much less nice to use.

function forward(::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
    return A * B, C -> pullback((signature=(typeof(*), typeof(A), typeof(B)), A=A, B=B), C)
end

Of course, writing things out this way isn't that helpful if you have to do it for every rule. But the only real difference is that the pullback has a name, which gives you an interface to overload it. We could get the same effect with something like (with appropriate sugar)

P = pullback_name(Tuple{typeof(*),AbstractMatrix,AbstractMatrix})
(::typeof(P))(Y) = ...

However, while this would solve the problem in a general way, I'm sceptical that even this is really needed. Why can't you make the adjoint a named type, rather than a named tuple, and overload * and +? That's all you need to support almost every rule in one go, and it seems unlikely that you'd want to change the actual meaning of the pullback, as opposed to just making linear maps work on your custom type.
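As a sketch of what that could look like: the type `DiagTangent` and its methods below are hypothetical, chosen only to illustrate a named cotangent type that overloads * and + so that an unmodified matrix-multiplication pullback accepts it.

```julia
# Hypothetical named cotangent type for a diagonal matrix with real entries.
struct DiagTangent{T<:Real}
    diag::Vector{T}
end

# Acts like Diagonal(d.diag) under multiplication with ordinary matrices.
Base.:*(d::DiagTangent, M::AbstractMatrix) = d.diag .* M                # scales rows
Base.:*(M::AbstractMatrix, d::DiagTangent) = M .* permutedims(d.diag)   # scales columns
Base.:+(a::DiagTangent, b::DiagTangent) = DiagTangent(a.diag + b.diag)

# The generic pullback body, written once and never specialised.
times_pullback_sketch(A, B, Ȳ) = (Ȳ * B', A' * Ȳ)
```

With these methods in place, `times_pullback_sketch(A, B, DiagTangent(v))` runs through the generic code path, which is the behaviour argued for here: the cotangent type adapts to the rules, not the other way around.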

oxinabox commented 5 years ago

> Why can't you make the adjoint a named type, rather than a named tuple, and overload * and +?

That will be the case with #8; I am calling that DNamedTuple.

> as opposed to just making linear maps work on your custom type.

The main case is that, in practice, being a linear map is not enough for some pullbacks. E.g. some of the LinearAlgebra ones want you to support various BLAS operations, or factorizations. So conceptually I can imagine that some types that might show up at some point don't support the operations in the default pullbacks but have their own weird ways to do the same thing. Or that for them the same thing is actually way faster if expressed in a different way.

willtebbutt commented 5 years ago

> Why can't you make the adjoint a named type, rather than a named tuple, and overload * and +?

This is tricky to do when mixing-and-matching custom adjoints with automatically derived ones as automatically derived rules will always produce a NamedTuple.

MikeInnes commented 5 years ago

> This is tricky to do when mixing-and-matching custom adjoints with automatically derived ones as automatically derived rules will always produce a NamedTuple.

Sure, but there's a finite number of such adjoints (I think just getproperty – are there others?) and an unlimited number of pullbacks that would otherwise need to have their behaviour overloaded.

Customising your adjoint type is something we can define a clear interface for and it'll work with custom adjoints defined in outside packages, whereas if you manually override each pullback it's only going to work with the set you specifically overloaded.

oxinabox commented 5 years ago

I am in favour of waiting and seeing.

We will certainly have a way to convert a NamedTuple to a DNamedTuple as part of #8 anyway. If we need more, we will deal with that then.

Particularly since, under my plan, the extra info you need will be housed in the default closure propagator anyway. (I am still a big fan of this notion: it makes it easy for rule writers to include the relevant information, since they use it there and then.)

oxinabox commented 4 years ago

Here is an instance of this in the wild for Zygote https://github.com/FluxML/Zygote.jl/blob/9af896e5eb9539adf7161ca3cadf4af9dfce0723/src/lib/array.jl#L388-L393

It special-cases AbstractMatrix and NamedTuple, but what is to say some future package won't want a similar kind of special treatment?

So I think we should do https://github.com/JuliaDiff/ChainRulesCore.jl/issues/53#issuecomment-533833058

oxinabox commented 3 years ago

Because of math reasons, it is very rare for an unexpected type to be passed to the pullback. We thus do not in general have an extensibility problem. The tangent type provided is pretty much determined by the output primal type, which in turn is determined by the primal input types for type-stable functions.

Thus generally one just adds another rrule.

There is a bit more to this story w.r.t. arrays, but we have a bit more of that story encoded, e.g. in ProjectTo.