JuliaDiff / ChainRules.jl

forward and reverse mode automatic differentiation primitives for Julia Base + StdLibs

An additional approach to implementing rules #338

Open willtebbutt opened 3 years ago

willtebbutt commented 3 years ago

There are basically two reasons to implement rules:

  1. to define AD. For example, you do have to tell an AD system somewhere how to differentiate addition and multiplication of floats,
  2. to make AD faster, without changing the semantics.

For 1 we obviously can't get around defining rules. For 2, however, we tend to implement rules in the same way as for 1: by completely overriding any particular AD and just telling it how to differentiate a thing. One thing that we've not explored to any great extent is rewriting code to make it more AD-friendly, and then just saying "run AD on this".

Leaving aside concerns about the best way to achieve a code re-write for a minute, suppose that you wished to implement an rrule for *(::AbstractMatrix, ::Diagonal). LinearAlgebra implements this as follows:

(*)(A::AbstractMatrix, D::Diagonal) =
    rmul!(copyto!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), A), D)

The problem from the perspective of a reverse-mode AD tool (that doesn't know how to handle mutation) is that the underlying implementation of this non-mutating operation is mutating. However, it is really quite clear how a non-mutating version of this operation could be implemented by looking at the definition of rmul!. Specifically, something like

A .* permutedims(D.diag)

This is the kind of code that we could plausibly hope to run one of our current (or near-future) reverse-mode AD tools on, and have it do something sensible, whereas there was really no hope with LinearAlgebra's definition.
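As a quick sanity check of the rewrite above, here is a minimal sketch comparing it against LinearAlgebra's own method. The name `rewritten_mul` is made up for illustration and is not part of any package.

```julia
using LinearAlgebra

# Hypothetical name: the non-mutating rewrite of A * D suggested above.
# permutedims turns the length-n diagonal into a 1×n row, so the broadcast
# scales column j of A by D.diag[j], which is what rmul!(copy_of_A, D) does.
rewritten_mul(A::AbstractMatrix, D::Diagonal) = A .* permutedims(D.diag)

A = [1.0 2.0; 3.0 4.0; 5.0 6.0]
D = Diagonal([10.0, 100.0])

# Compare against the mutating-under-the-hood definition in LinearAlgebra.
matches = rewritten_mul(A, D) == A * D
```

The rewrite contains only a broadcast and `permutedims`, so an AD tool never sees `copyto!` or `rmul!`.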

Moreover, this kind of approach seems simpler for the rule-writer: rather than having to know how to differentiate a function, the rule-implementer just needs to know how to re-write the primal pass in a way that is more friendly towards AD.

This kind of approach is only valuable if there's functionality that could be easily re-written in an AD-friendly manner. My hypothesis is that there is lots of functionality in Base / the standard libraries that satisfies this because it was implemented by

  1. implementing a mutating version of a function (e.g. gemm!)
  2. implementing the non-mutating version of a function in terms of the mutating version. (e.g. gemm)
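The two-step pattern in the list above can be sketched with a toy pair of functions. `affine!`, `affine` and `affine_rewritten` are hypothetical names chosen for illustration; they are not taken from Base or the standard libraries.

```julia
# Step 1: a hypothetical mutating kernel that writes a * x[i] + b into `out`.
function affine!(out::AbstractVector, a::Number, x::AbstractVector, b::Number)
    for i in eachindex(out, x)
        out[i] = a * x[i] + b
    end
    return out
end

# Step 2: the non-mutating version, defined in terms of the mutating one.
# A reverse-mode tool without mutation support gets stuck inside affine!.
function affine(a::Number, x::AbstractVector, b::Number)
    T = promote_type(typeof(a), eltype(x), typeof(b))
    return affine!(similar(x, T), a, x, b)
end

# The AD-friendly rewrite: same result, expressed without mutation.
affine_rewritten(a::Number, x::AbstractVector, b::Number) = a .* x .+ b

x = [1.0, 2.0, 3.0]
same = affine(2.0, x, 1.0) == affine_rewritten(2.0, x, 1.0)
```

Under the proposal, the rule-writer would only have to supply something like `affine_rewritten` and ask the AD system to differentiate that instead.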

This could provide a very simple partial solution to #232 by alleviating the need for generic rules in favour of code rewrites, which are much more straightforward to achieve and let the AD system auto-generate appropriate cotangents.

Thoughts on the principle? We can discuss implementation details once we've established whether or not we basically like the idea.

sethaxen commented 3 years ago

This would be useful. Zygote uses this approach in a few places (with Zygote._pullback), and it can be convenient.

For another example, ^(A::Hermitian, p::Integer) is almost differentiable by Zygote. It dispatches to power_by_squaring, which is non-mutating, so everything is great, except that for Hermitian{<:Complex} it insists on explicitly realifying the diagonal (even though getindex(A::Hermitian, i, i) does that implicitly). That step mutates, so now Zygote needs a whole implementation of a pullback of Base.power_by_squaring to support that function. It's terrible.

With the suggested approach, we would just reimplement ^(A::Hermitian{<:Complex}, p::Integer) to either leave the diagonal alone or make a copy to realify it. It would be short and simple, require much less work, and be Zygote-compatible.
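A rough sketch of what such a reimplementation might look like, assuming it is acceptable to go through a dense `Matrix`. The name `hermitian_power` is hypothetical, and whether a given AD tool can differentiate every call in it is an assumption that has not been checked here.

```julia
using LinearAlgebra

# Hypothetical non-mutating stand-in for ^(A::Hermitian{<:Complex}, p::Integer).
function hermitian_power(A::Hermitian{<:Complex}, p::Integer)
    # Integer powers of a dense matrix; no in-place edits of the result follow.
    B = Matrix(A)^p
    # "Realify" the diagonal by building a new matrix instead of mutating B:
    # remove the original diagonal, then add back only its real part.
    d = diag(B)
    C = (B - Diagonal(d)) + Diagonal(real.(d))
    return Hermitian(C)
end

A = Hermitian(ComplexF64[2 1+im; 1-im 3])
agrees = Matrix(hermitian_power(A, 3)) ≈ Matrix(A^3)
```

The other option mentioned above, leaving the diagonal alone, would reduce the body to `Hermitian(Matrix(A)^p)`.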

But is this any easier than enabling a rule to call back into an AD, discussed in https://github.com/JuliaDiff/ChainRulesCore.jl/issues/68?

willtebbutt commented 3 years ago

> But is this any easier than enabling a rule to call back into an AD, discussed in JuliaDiff/ChainRulesCore.jl#68?

It feels like it has a slightly different set of requirements that might be a bit simpler to handle, in a similar vein to https://github.com/JuliaDiff/ChainRulesCore.jl/issues/270 .

edit: I'm glad to hear that you found it easy to arrive at a use-case!

mzgubic commented 3 years ago

> This would be useful. Zygote uses this approach in a few places (with Zygote._pullback), and it can be convenient.

Doesn't Zygote._pullback still require writing a backward pass explicitly?

sethaxen commented 3 years ago

> Doesn't Zygote._pullback still require writing a backward pass explicitly?

No, it doesn't. At some level you need to hit primitives that have defined adjoints/rrules, but you don't need to do it for the entire function. A really nice example of this was the now-removed workaround for norm: https://github.com/FluxML/Zygote.jl/blob/eee717ae56b424007a6f1587d21a9b5c89a7a92f/src/lib/array.jl#L424-L427 This call to _pullback just intercepted any call to norm and replaced it with a less sophisticated but more Zygote-friendly function containing only function calls that Zygote knows how to AD.

> edit: I'm glad to hear that you found it easy to arrive at a use-case!

I have so many use-cases for this. 🙂

mzgubic commented 3 years ago

Oh I see, that's pretty neat! Thanks for the example.

niklasschmitz commented 3 years ago

With the now-merged mechanism for calling back into AD, I found this pattern to be quite useful!

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(complicated_f), args...)
    rrule_via_ad(config, simple_f, args...)
end

which could probably be done by a convenience macro like @alternative_primal complicated_f simple_f.
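One possible shape for such a macro, as a sketch only: `@alternative_primal` is the name floated above and is not, as far as this sketch assumes, an existing ChainRulesCore macro. `complicated_f` and `simple_f` are placeholder functions, and keyword arguments are not forwarded.

```julia
using ChainRulesCore

# Hypothetical convenience macro: define an rrule for `complicated` that
# delegates to whatever AD system owns `config`, run on `simple` instead.
macro alternative_primal(complicated, simple)
    quote
        function ChainRulesCore.rrule(
            config::RuleConfig{>:HasReverseMode},
            ::typeof($(esc(complicated))),
            args...,
        )
            return rrule_via_ad(config, $(esc(simple)), args...)
        end
    end
end

# Placeholder primal that mutates internally.
function complicated_f(x)
    buf = [x]
    buf[1] *= x
    return buf[1]
end

# Placeholder AD-friendly equivalent with the same arguments and result.
simple_f(x) = x * x

@alternative_primal complicated_f simple_f
```

The pullback returned by `rrule_via_ad` already has one slot for the function itself plus one per argument, so it can be returned unchanged as long as `simple` takes the same arguments as `complicated`.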

I'm currently trying it for differentiating through config constructors of physics solvers, which in our case are large structs with heterogeneous field types, few of which are differentiable. Other fields involve a lot of pre-computation, including calls into non-Julia code, mutation, FFT size heuristics and many sanity-check assertions, all of which are nice to bypass explicitly when constructing pullbacks.