Open willtebbutt opened 3 years ago
This would be useful. Zygote uses this approach in a few places (with Zygote._pullback), and it can be convenient.
For another example, ^(A::Hermitian, p::Integer) is almost differentiable by Zygote. It dispatches to power_by_squaring, which is non-mutating, so everything is great, except that for Hermitian{<:Complex} it insists on explicitly realifying the diagonal (even though getindex(A::Hermitian, i, i) does that implicitly). That step mutates, so now Zygote needs a whole implementation of a pullback of Base.power_by_squaring to support that function. It's terrible.
With the suggested approach, we would just reimplement ^(A::Hermitian{<:Complex}, p::Integer) to either leave the diagonal alone or make a copy to realify it. It would be short and simple, require much less work, and be Zygote-compatible.
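A minimal sketch of the "leave the diagonal alone" option, using only LinearAlgebra. The names hermitian_power and _pow_by_squaring are hypothetical and not part of Base, LinearAlgebra, or Zygote, and the sketch only handles p >= 1. Whether a given AD tool can differentiate it depends on that tool having rules for matrix multiplication and for the Hermitian constructor, which is assumed here rather than checked.

```julia
using LinearAlgebra

# Hypothetical helper: non-mutating power by squaring for p >= 1.
# Every step allocates a new matrix instead of writing in place.
function _pow_by_squaring(x, p::Integer)
    p == 1 && return x
    half = _pow_by_squaring(x, p ÷ 2)
    sq = half * half
    return isodd(p) ? sq * x : sq
end

# Hypothetical replacement for ^(A::Hermitian{<:Complex}, p::Integer).
# The diagonal is never touched: wrapping in Hermitian means indexing
# the diagonal returns real values, so no explicit realifying pass is needed.
function hermitian_power(A::Hermitian, p::Integer)
    p >= 1 || throw(ArgumentError("this sketch only handles p >= 1"))
    p == 1 && return A
    return Hermitian(_pow_by_squaring(A, p))
end
```

For example, with A = Hermitian([2.0+0im 1.0-1im; 1.0+1im 3.0+0im]), hermitian_power(A, 3) should agree with A^3 up to floating-point error.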
But is this any easier than enabling a rule to call back into an AD, discussed in https://github.com/JuliaDiff/ChainRulesCore.jl/issues/68?
It feels like it has a slightly different set of requirements that might be a bit simpler to handle, in a similar vein to https://github.com/JuliaDiff/ChainRulesCore.jl/issues/270 .
edit: I'm glad to hear that you found it easy to arrive at a use-case!
> This would be useful. Zygote uses this approach in a few places (with Zygote._pullback), and it can be convenient.
Doesn't Zygote._pullback still require writing a backward pass explicitly?
No, it doesn't. At some level you need to hit primitives that have defined adjoints/rrules, but you don't need to do it for the entire function. A really nice example of this was the now-removed workaround for norm:
https://github.com/FluxML/Zygote.jl/blob/eee717ae56b424007a6f1587d21a9b5c89a7a92f/src/lib/array.jl#L424-L427
This call to _pullback just intercepts any call to norm and replaces it with a less sophisticated but more Zygote-friendly function containing only function calls Zygote knows how to AD.
> edit: I'm glad to hear that you found it easy to arrive at a use-case!
I have so many use-cases for this. 🙂
Oh I see, that's pretty neat! Thanks for the example
With the now-merged mechanism for calling back into AD, I found this pattern to be quite useful:
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(complicated_f), args...)
    rrule_via_ad(config, simple_f, args...)
end
which could probably be done by a convenience macro like @alternative_primal complicated_f simple_f.
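One way such a macro might look, as a sketch only: @alternative_primal is the hypothetical name suggested above and does not exist in ChainRulesCore. It assumes ChainRulesCore is installed and that both functions are already defined when the macro is used. It simply generates the rrule shown above.

```julia
using ChainRulesCore

# Hypothetical convenience macro (not part of ChainRulesCore).
# It defines an rrule for `complicated_f` that asks the calling AD
# to differentiate `simple_f` instead.
macro alternative_primal(complicated_f, simple_f)
    return esc(quote
        function ChainRulesCore.rrule(
            config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode},
            ::typeof($complicated_f),
            args...,
        )
            return ChainRulesCore.rrule_via_ad(config, $simple_f, args...)
        end
    end)
end
```

The generated rule only applies when the AD system passes a RuleConfig that advertises HasReverseMode. Other ADs would fall back to whatever they normally do with complicated_f.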
I'm currently trying it for differentiating through config constructors of physics solvers, which in our case are large structs with heterogeneous field types, few of which are differentiable. Other fields involve a lot of pre-computation, including calls into non-Julia code, mutation, FFT size heuristics and many sanity-check assertions, all of which are nice to bypass explicitly when constructing pullbacks.
There are basically two reasons to implement rules:
For 1 we obviously can't get around defining rules. For 2, however, we tend to implement rules in the same way as for 1 -- by completely overriding any particular AD and just telling it how to differentiate a thing. One thing that we've not explored to any great extent is re-writing code to make it more AD-friendly, and then just saying "run AD on this".
Leaving aside concerns about the best way to achieve a code re-write for a minute, suppose that you wished to implement an rrule for *(::AbstractMatrix, ::Diagonal). LinearAlgebra implements this as follows:
The problem from the perspective of a reverse-mode AD tool (that doesn't know how to handle mutation) is that the underlying implementation of this non-mutating operation is mutating. However, it is really quite clear how a non-mutating version of this operation could be implemented by looking at the definition of rmul!. Specifically, something like
This is the kind of code that we could plausibly hope to run one of our current (or near-future) reverse-mode AD tools on, and have it do something sensible, whereas there was really no hope with LinearAlgebra's definition.
Moreover, this kind of approach seems simpler for the rule-writer: rather than having to know how to differentiate a function, the rule-implementer just needs to know how to re-write the primal pass in a way that is more friendly towards AD.
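As a minimal sketch of what such a non-mutating rewrite could look like (not necessarily the version the author had in mind), right-multiplying by a Diagonal scales column j of A by the j-th diagonal entry, which can be written as a single broadcast. The name right_mul_diag is hypothetical and used only for illustration.

```julia
using LinearAlgebra

# Hypothetical non-mutating stand-in for *(::AbstractMatrix, ::Diagonal).
# transpose (rather than ') is used so that complex diagonal entries
# are not conjugated.
right_mul_diag(A::AbstractMatrix, D::Diagonal) = A .* transpose(D.diag)

A = [1.0 2.0; 3.0 4.0]
D = Diagonal([10.0, 100.0])
right_mul_diag(A, D) == A * D  # true
```

The body contains only a broadcast and a transpose, both of which reverse-mode AD tools commonly know how to handle, so no hand-written pullback would be needed under that assumption.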
This kind of approach is only valuable if there's functionality that could be easily re-written in an AD-friendly manner. My hypothesis is that there is lots of functionality in Base / the standard libraries that satisfies this because it was implemented by
gemm!)
gemm)
This could provide a very simple partial solution to #232 by alleviating the need for generic rules in favour of code re-writes, which are much more straightforward to achieve and let the AD system auto-generate appropriate cotangents.
Thoughts on the principle? We can discuss implementation details once we've established whether or not we basically like the idea.