Open adrhill opened 2 months ago
The trouble will be the same as for `DifferentiateWith`: it won't work with all backends.
For instance there's no way to tell FiniteDiff not to differentiate through a function (unlike Zygote or Enzyme).
But here it is worse than for `DifferentiateWith` because several backends may lead to different answers without erroring. At least, when it doesn't error, `DifferentiateWith` gives the same output regardless of whether the custom chain rule is hit.
So I guess there are 2 different use cases here:
It seems like (1) is tricky to support for the reasons in https://github.com/gdalle/DifferentiationInterface.jl/issues/415#issuecomment-2296967975, but (2) should be safe, right? Since with e.g. FiniteDifferences we wouldn't want or need the behavior to change at all, while with operator-overloading ADs we'd want to strip away types and with source-to-source ADs we'd want to tell them to just compile the usual function.
I guess you're right for 2. Essentially this would only concern `@non_differentiable`.
Can you think of any other backend where it would work?
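For reference, a minimal sketch of what the `@non_differentiable` route looks like with ChainRulesCore. The helper `count_entries` is hypothetical and only serves as an illustration; only backends that consume ChainRules rules would see this declaration.

```julia
using ChainRulesCore

# Hypothetical helper, not from this thread: its result should be treated as a constant.
count_entries(x) = length(x)

# Declare that no derivative information flows through `count_entries`.
# This generates both an `frule` and an `rrule` returning `NoTangent()`.
@non_differentiable count_entries(::Any)

# A ChainRules-compatible backend would pick up the generated rule like this:
y, pullback = rrule(count_entries, [1.0, 2.0, 3.0])
```

Backends that do not read ChainRules rules (e.g. FiniteDiff) would simply ignore the declaration.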
With Enzyme, I think one could use `inactive`: https://enzyme.mit.edu/index.fcgi/julia/stable/generated/custom_rule/#Marking-functions-inactive. For source-to-source I think one would want it to also work for Tapir if possible.
It should ideally impact all of the operator-overloading ADs as well, at least ForwardDiff, ReverseDiff, and Tracker. Probably Symbolics also.
Contributions are welcome on this; I'm not yet comfortable enough with metaprogramming to try it alone.
Inactive's semantics are different from ChainRules'. Marking something as inactive in Enzyme.jl says both that the operation itself doesn't transfer derivative information and that no value produced by the function call could contain differentiable data in the future.
E.g. allocating an empty vector would be legal to mark as ChainRules inactive, but not Enzyme inactive.
In that sense Enzyme inactive implies ChainRules inactive (I think).
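A rough sketch of that distinction, assuming the `EnzymeRules.inactive` method-definition pattern from the Enzyme documentation linked above. Both helper functions are hypothetical and chosen only to illustrate the two cases.

```julia
using Enzyme

# Hypothetical helper: it returns an integer, which can never carry derivative data,
# so marking it inactive fits the stricter Enzyme semantics described above.
count_entries(x) = length(x)
Enzyme.EnzymeRules.inactive(::typeof(count_entries), args...) = nothing

# Hypothetical counter-example: the returned buffer may later hold differentiable
# values, so it should NOT be marked inactive for Enzyme, even though a
# ChainRules-style non-differentiable marking would be legal for it.
make_buffer(n) = Vector{Float64}(undef, n)
```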
Enzyme's activity analysis also supports the ability to specify that just the instruction (e.g. function call) doesn't transfer derivative information without discussing the return, but we haven't added syntactic sugar to Julia for that yet.
This could be implemented in a similar way to `DifferentiateWith`.