gdalle / DifferentiationInterface.jl

An interface to various automatic differentiation backends in Julia.
https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterface
MIT License

Add common interface to mark functions as non-differentiable across backends #415

Open · adrhill opened this issue 3 weeks ago

adrhill commented 3 weeks ago

This could be implemented in a similar way to DifferentiateWith.

gdalle commented 3 weeks ago

The trouble will be the same as for DifferentiateWith: it won't work with all backends. For instance, there's no way to tell FiniteDiff not to differentiate through a function (unlike Zygote or Enzyme).

But here it is worse than for DifferentiateWith, because different backends may silently return different answers instead of erroring. With DifferentiateWith, at least, whenever it doesn't error, the output is the same regardless of whether the custom chain rule is hit.
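To make the concern concrete, here is a toy Python model. It is not Julia code and not the API of any backend. A minimal dual-number forward mode stands in for an operator-overloading AD, and a central finite difference stands in for FiniteDiff. `Dual`, `stop_gradient`, `ad_derivative` and `finite_difference` are all hypothetical helpers. When the marked call actually depends on the input, the two "backends" disagree without any error.

```python
class Dual:
    """Minimal forward-mode dual number (value, tangent); a toy, not ForwardDiff."""

    def __init__(self, value, tangent=0.0):
        self.value = value
        self.tangent = tangent

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule on the tangent part.
        return Dual(
            self.value * other.value,
            self.tangent * other.value + self.value * other.tangent,
        )

    __rmul__ = __mul__


def stop_gradient(x):
    """Hypothetical marker: an operator-overloading AD simply drops the tangent."""
    return x.value if isinstance(x, Dual) else x


def f(x):
    # The second factor is marked non-differentiable although it depends on x.
    return x * stop_gradient(x)


def ad_derivative(func, x):
    """Forward mode: seed the tangent with 1.0 and read it back from the output."""
    out = func(Dual(x, 1.0))
    return out.tangent if isinstance(out, Dual) else 0.0


def finite_difference(func, x, h=1e-6):
    """Central difference: only evaluates on floats, so it never sees the marker."""
    return (func(x + h) - func(x - h)) / (2 * h)


if __name__ == "__main__":
    print(ad_derivative(f, 3.0))      # 3.0: derivative of x * c with c = x frozen
    print(finite_difference(f, 3.0))  # about 6.0: derivative of x * x
```

The forward-mode result treats the marked factor as a constant, while the finite difference sees plain `x * x`. Neither raises, which is the silent disagreement described above.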

sethaxen commented 3 weeks ago

So I guess there are 2 different use cases here:

  1. Marking as non-differentiable a function that would otherwise make a non-zero contribution to the differential, thus deliberately changing the differential.
  2. Marking as non-differentiable a function call that cannot contribute to the differential, to avoid either performance issues or errors that the AD backend raises on unsupported language features.

It seems like (1) is tricky to support for the reasons in https://github.com/gdalle/DifferentiationInterface.jl/issues/415#issuecomment-2296967975, but (2) should be safe, right? With e.g. FiniteDifferences we wouldn't want or need the behavior to change at all. With operator-overloading ADs we'd want to strip away the AD types (e.g. dual numbers or tracked values), and with source-to-source ADs we'd want to tell them to just compile the usual function.
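Here is a toy Python sketch of why use case (2) is safe. It assumes that "strip away the AD types" means calling the marked function on primal values only. `Dual`, `non_differentiable`, `check_finite` and the two derivative helpers are hypothetical. `math.isfinite` rejecting a `Dual` plays the role of an AD backend erroring on an unsupported feature.

```python
import math


class Dual:
    """Minimal forward-mode dual number (value, tangent); a toy, not ForwardDiff."""

    def __init__(self, value, tangent=0.0):
        self.value = value
        self.tangent = tangent

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(
            self.value * other.value,
            self.tangent * other.value + self.value * other.tangent,
        )

    __rmul__ = __mul__


def non_differentiable(func):
    """Hypothetical marker: strip AD types and call func on primal values."""

    def wrapped(*args):
        primals = [a.value if isinstance(a, Dual) else a for a in args]
        return func(*primals)

    return wrapped


def check_finite(x):
    """Validation whose result never feeds into the output value."""
    # math.isfinite raises TypeError on a Dual (it has no __float__),
    # standing in for an AD backend that errors on an unsupported feature.
    if not math.isfinite(x):
        raise ValueError("non-finite input")


safe_check_finite = non_differentiable(check_finite)


def f_plain(x):
    check_finite(x)
    return x * x


def f_marked(x):
    safe_check_finite(x)
    return x * x


def ad_derivative(func, x):
    out = func(Dual(x, 1.0))
    return out.tangent if isinstance(out, Dual) else 0.0


def finite_difference(func, x, h=1e-6):
    return (func(x + h) - func(x - h)) / (2 * h)


if __name__ == "__main__":
    try:
        ad_derivative(f_plain, 3.0)
    except TypeError as err:
        print("unmarked call fails under AD:", err)
    print(ad_derivative(f_marked, 3.0))      # 6.0
    print(finite_difference(f_marked, 3.0))  # about 6.0, unaffected by the marker
```

Because the marked call cannot influence the output, stripping the types only removes the error. The derivative matches what a finite-difference backend computes with no marker support at all.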

gdalle commented 3 weeks ago

I guess you're right for 2. Essentially this would only concern

Can you think of any other backend where it would work?

sethaxen commented 3 weeks ago

With Enzyme, I think one could use inactive: https://enzyme.mit.edu/index.fcgi/julia/stable/generated/custom_rule/#Marking-functions-inactive. For source-to-source ADs, I think one would want it to also work for Tapir if possible.

Ideally it should also cover the operator-overloading ADs: at least ForwardDiff, ReverseDiff, and Tracker, and probably Symbolics too.

gdalle commented 3 weeks ago

Contributions are welcome on this; I'm not yet comfortable enough with metaprogramming to try it alone.

wsmoses commented 3 weeks ago

Enzyme's inactive has different semantics from ChainRules. Marking something as inactive in Enzyme.jl says both that the operation itself doesn't transfer derivative information and that no value produced by the function call can contain differentiable data in the future.

For example, allocating an empty vector would be legal to mark as ChainRules-inactive, but not as Enzyme-inactive, since the vector may later be filled with differentiable values.

In that sense, Enzyme-inactive implies ChainRules-inactive (I think).

Enzyme's activity analysis also supports specifying that only the instruction (e.g. a function call) doesn't transfer derivative information, without saying anything about its return value, but we haven't added syntactic sugar for that in Julia yet.
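The empty-vector example can be illustrated with a toy Python model. As a simplification, and not as a description of Enzyme's internals, it assumes "Enzyme-inactive" means the returned buffer can never carry derivative data, while "ChainRules-inactive" only means the allocation call itself contributes nothing. `PrimalOnlyBuffer`, the two `alloc_*` functions and `make_f` are hypothetical names.

```python
class Dual:
    """Minimal forward-mode dual number (value, tangent); a toy, not Enzyme."""

    def __init__(self, value, tangent=0.0):
        self.value = value
        self.tangent = tangent

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(
            self.value * other.value,
            self.tangent * other.value + self.value * other.tangent,
        )

    __rmul__ = __mul__


class PrimalOnlyBuffer:
    """Toy stand-in for a value promised never to hold differentiable data."""

    def __init__(self, n):
        self._data = [0.0] * n

    def __setitem__(self, i, v):
        # Any tangent carried by v is silently discarded.
        self._data[i] = v.value if isinstance(v, Dual) else v

    def __getitem__(self, i):
        return self._data[i]


def alloc_chainrules_inactive(n):
    # The allocation itself has zero derivative, but the returned list
    # can still hold Dual numbers written into it later.
    return [0.0] * n


def alloc_enzyme_inactive(n):
    # Additionally promises the result will never hold derivative data.
    return PrimalOnlyBuffer(n)


def make_f(alloc):
    def f(x):
        buf = alloc(1)
        buf[0] = x * x  # a differentiable value stored into the buffer
        return buf[0]

    return f


def ad_derivative(func, x):
    out = func(Dual(x, 1.0))
    return out.tangent if isinstance(out, Dual) else 0.0


if __name__ == "__main__":
    print(ad_derivative(make_f(alloc_chainrules_inactive), 3.0))  # 6.0
    print(ad_derivative(make_f(alloc_enzyme_inactive), 3.0))      # 0.0, derivative lost
```

The first marking is sound, since the allocation alone contributes nothing to the derivative. The second marking makes the stronger promise about the returned value, and that promise is broken as soon as a differentiable value is written into the buffer.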