Very nice, I'm excited about this!
Note that `custom_jvp` already suffices for `lax.custom_root` (supposing that the linear solve is transposable), but not for `lax.custom_linear_solve` (which needs transposition).
Another major motivator is `odeint`: right now it has a `custom_vjp` rule, precluding forward-mode differentiation. But we know how to decompose ODE differentiation into forward mode, partial evaluation, and transposition. We just need a way to register custom transposes for things like linear ODE solves!
I've added a section to the issue description that covers how custom transposition, once realized, might allow for a re-implementation of `jax.custom_vjp`, either as it behaves today or with forward-mode AD support in some form.
I've added another section highlighting the linearity requirement on the target function.
Support custom transposition, i.e. the ability to register a custom "transpose rule" for any given function.
A function `f` marked for custom transposition and its transpose `t` take two arguments each, where the `r` argument represents residuals. Typically, we would use the notation `a -o b` to mean "a structurally linear function `a -> b`" and say that if `t` is a transpose of `f`, then their signatures are related as in the sketch below. But we won't actually require structural linearity here, only "numerical" linearity. That's part of the very point of customization.
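Presumably something like the following, taking the residuals as the first argument (as in the examples later on):

```
f :: r -> a -o b
t :: r -> b -o a
```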
Example usage would look something like:
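(The decorator and `def_transpose` names below follow the API proposed here and are illustrative, not final.)

```python
import jax
import jax.numpy as jnp

@jax.custom_transpose          # proposed decorator, name illustrative
def f(r, x):
  # Numerically linear in x (it computes r * x for positive inputs), but
  # not structurally linear, so JAX couldn't transpose it on its own.
  return jnp.exp(jnp.log(r) + jnp.log(x))

@f.def_transpose               # proposed registration method
def f_transpose(r, ct):
  # Transpose of the linear map x -> r * x, given residuals r.
  return r * ct
```

Transposition of `f` in its second argument (e.g. as part of reverse-mode AD) would then use `f_transpose` rather than attempting to transpose the `exp`/`log` computation structurally.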
This is likely to subsume `linear_call` from #5781.

### Application: fewer primitives
Among other things, `lax.custom_root` and `lax.custom_linear_solve` may no longer need to be primitives. When used together with `jax.custom_jvp`, this would enable customization of both forward- and reverse-mode AD (`custom_vjp` currently disallows forward mode). To illustrate this with a simplified linear solve function:
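(Again a sketch, using the proposed `custom_transpose`/`def_transpose` names from above; the simplified solve is a hypothetical stand-in.)

```python
import jax
import jax.numpy as jnp

# solve(a, b) is treated as linear in b, with the matrix a as residuals.
@jax.custom_transpose              # proposed decorator
def solve(a, b):
  return jnp.linalg.solve(a, b)

@solve.def_transpose               # proposed registration method
def solve_transpose(a, ct):
  # The transpose of b -> solve(a, b) is ct -> solve(a.T, ct).
  return solve(a.T, ct)

@jax.custom_jvp
def linsolve(a, b):
  return solve(a, b)

@linsolve.defjvp
def linsolve_jvp(primals, tangents):
  a, b = primals
  da, db = tangents
  x = solve(a, b)
  # Differentiating a @ x = b gives a @ dx = db - da @ x, so the tangent
  # output is itself a solve, linear in (da, db). This inner `solve` is the
  # custom-transposable one above.
  tx = solve(a, db - da @ x)
  return x, tx
```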
JAX derives VJPs from JVPs by linearization and transposition. In this case, the system will pick up the custom transpose of `solve`, and the derived VJP will also carry out a `solve` (against the transposed design matrix). More generally, forward-mode (JVP) behavior is altered by `custom_jvp` as usual, and reverse-mode (VJP) behavior is altered by any custom transposes in the dependence path of the tangent output (`tx`).

### Application: upgrading custom VJPs
We could imagine re-implementing our `jax.custom_vjp` functionality on top of `custom_jvp` and `custom_transpose`. A sketch:
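(Roughly as follows, glossing over pytree handling, symbolic zeros, and error checking, and assuming the proposed `custom_transpose` only evaluates the decorated callable when its primal output is actually needed.)

```python
import jax

def my_custom_vjp(f, fwd, bwd):
  """Sketch: recover custom_vjp from custom_jvp plus custom transposition."""

  @jax.custom_jvp
  def f_wrapped(*args):
    return f(*args)

  @f_wrapped.defjvp
  def f_jvp(primals, tangents):
    out, res = fwd(*primals)

    # The tangent output comes from a function of the tangents whose
    # registered transpose is the user's bwd. Reverse-mode AD (linearize,
    # then transpose) therefore dispatches to bwd, as custom_vjp does.
    @jax.custom_transpose            # proposed decorator
    def tangent_map(res_, tans):
      # Evaluating this forward would amount to forward-mode AD, which
      # current custom_vjp disallows; see the upgrade options below.
      raise NotImplementedError("no forward-mode rule")

    @tangent_map.def_transpose       # proposed registration method
    def tangent_map_transpose(res_, ct):
      return bwd(res_, ct)

    return out, tangent_map(res, tangents)

  return f_wrapped
```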
This recovers the current `custom_vjp` behavior. It also opens up possibilities for forward-mode AD support in the presence of `custom_vjp`, something the current implementation doesn't support. There are at least two options for defining forward-mode behavior.

### Upgrade option 1: derive a JVP by transposing the VJP
This requires that the `bwd` function supplied to `custom_vjp` be structurally linear.
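For instance, a sketch in place of the `f_jvp` rule above, using `jax.linear_transpose` to transpose `bwd` (with `out` supplying only shape/dtype information):

```python
@f_wrapped.defjvp
def f_jvp(primals, tangents):
  out, res = fwd(*primals)

  @jax.custom_transpose            # proposed decorator, as above
  def tangent_map(res_, tans):
    # Forward-mode behavior: transpose ct -> bwd(res_, ct), which must be
    # structurally linear in ct. `out` is used only for its shape/dtype.
    bwd_transposed = jax.linear_transpose(lambda ct: bwd(res_, ct), out)
    tan_out, = bwd_transposed(tans)
    return tan_out

  @tangent_map.def_transpose       # proposed registration method
  def tangent_map_transpose(res_, ct):
    return bwd(res_, ct)           # reverse mode still uses bwd directly

  return out, tangent_map(res, tangents)
```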
### Upgrade option 2: derive a JVP by linearizing the "forward" function

This requires that the `fwd` function supplied to `custom_vjp` be (forward-mode) automatically differentiable, as in the sketch below. (This repeats some work by interpreting `fwd` twice: once to compute primals and to linearize, and then again to grab the user-defined custom residuals. We could avoid this by using our "auxiliary output" machinery in AD to obtain those residuals concurrently with the linearization process, e.g. if we exposed the `has_aux` option of `ad.linearize` up through `jax.linearize`.)
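(Again a sketch, in place of the `f_jvp` rule in the sketch above.)

```python
@f_wrapped.defjvp
def f_jvp(primals, tangents):
  out, res = fwd(*primals)           # one pass over fwd: custom residuals

  @jax.custom_transpose              # proposed decorator, as above
  def tangent_map(res_, tans):
    # Forward-mode behavior: linearize fwd's primal output (a second pass
    # over fwd); requires fwd to be forward-mode differentiable.
    _, tan_out = jax.jvp(lambda *args: fwd(*args)[0], primals, tans)
    return tan_out

  @tangent_map.def_transpose         # proposed registration method
  def tangent_map_transpose(res_, ct):
    return bwd(res_, ct)             # reverse mode still uses bwd directly

  return out, tangent_map(res, tangents)
```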
### On linearity assumptions

Although a `custom_transpose`'d function and its transpose rule need not be structurally linear, JAX may assume that they are nonetheless mathematically linear. This assumption is somewhat inevitable, and we should highlight it in documentation. An example consequence of this assumption: if one writes, for some functions `f` and `g`, something like:
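(Using the proposed decorator name, illustratively:)

```python
@jax.custom_transpose    # proposed decorator
def f(r, x):
  return g(r, x)         # g: some arbitrary function of r and x
```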
then JAX's AD system may consider the following a valid JVP for `f`:
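(One possible rendering of such a JVP rule:)

```python
def f_jvp(primals, tangents):
  (r, x), (_, tx) = primals, tangents   # the residual tangent is ignored
  # Tangents pass through f itself, on the assumption that f(r, .) is
  # mathematically linear in x, regardless of what g actually computes.
  return f(r, x), f(r, tx)
```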
This is hopefully an unsurprising requirement, since the notion of transposition only applies to linear maps to begin with. That said, we could imagine applications that might willingly break the linearity requirement. An example is the following (arguably) natural implementation of "gradient clipping", as is somewhat common in neural network training:
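(A guess at the kind of implementation meant here: `clip` registered as its own transpose, so that cotangents are clipped by the same threshold.)

```python
import jax
import jax.numpy as jnp

@jax.custom_transpose    # proposed decorator
def clip(threshold, x):
  return jnp.clip(x, -threshold, threshold)

@clip.def_transpose      # proposed registration method
def clip_transpose(threshold, ct):
  # Clip the incoming cotangent by the same threshold ("gradient clipping").
  return jnp.clip(ct, -threshold, threshold)
```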
If `clip` is itself ever automatically differentiated, the caller might be surprised that its derivative is `threshold` past `threshold`, rather than 0, even though the derivative of `clip` was never explicitly customized.