Open thisiscam opened 4 years ago
Thanks for the question!
No, that's not currently possible. It's documented in the custom derivatives tutorial (see the subsection "Forward-mode autodiff cannot be used on the `jax.custom_vjp` function and will raise an error" for more).
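For readers who want to see the limitation concretely, here is a minimal sketch. The function `f` and its rules are made up for illustration, and the exact exception type may vary by JAX version (the tutorial shows a `TypeError`), so the sketch catches a generic `Exception`.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return jnp.sin(x)

def f_fwd(x):
    # Forward pass: return the output plus the residuals the backward pass needs.
    return jnp.sin(x), jnp.cos(x)

def f_bwd(cos_x, g):
    # Backward pass: one cotangent per input, returned as a tuple.
    return (cos_x * g,)

f.defvjp(f_fwd, f_bwd)

# Reverse mode goes through f_fwd / f_bwd and works.
reverse_grad = jax.grad(f)(1.0)

# Forward mode on the same function is expected to raise.
try:
    jax.jvp(f, (1.0,), (1.0,))
    forward_mode_error = None
except Exception as err:  # documented as a TypeError in the tutorial
    forward_mode_error = err

print(reverse_grad)
print(type(forward_mode_error).__name__, forward_mode_error)
```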
Can you say more about why you might need to define both?
Aha, didn't see that sentence in the tutorial.
My use case is quite intricate and very particular. I was trying to fix https://github.com/tensorflow/probability/issues/1045.
The idea is that oryx defines this "tagging" transformation; what I want is for that tag to be preserved after jacfwd/jacrev/grad transformations. So I wanted to override both the jacfwd and jacrev for the tagged value (actually, a function that returns those tagged values); the overridden behavior is to just tag the gradient values. It's a bit hacky IMO, but it's at least an attempt at this.
There are a few ways you could solve this, but one would be to create your own version of Oryx's `sow` primitive that has different behavior under JVP and transpose transformations (another would be to expose more control over how `sow` itself behaves, along the lines of Oryx's existing support for controlling how `sow` behaves under a `scan`).
CC @sharadmv
(Separately, the most straightforward answer to "how do I have a custom JVP and VJP at the same time" is to create a full-blown `Primitive`.)
What are your thoughts on this: rather than tying `harvest` to `sow` exclusively, it could look inside a set of primitives that all have `name`/`tag`/`mode` params. That would let users write their own primitives with whatever specialized semantics they want, then add them to the set.
Also, I realize now that this discussion is a better fit for the Oryx thread.
I'd like to request the ability to define both a `custom_vjp` and a `custom_jvp`.
I work on PennyLane, where we register derivatives computed on quantum computers with JAX.
Currently, we compute the full jacobian under a `pure_callback`, then let JAX trace the computation of the jacobian product. This allows us to provide both the jvp and vjp.
If we were to use a `pure_callback` around either the jvp or the vjp instead, we could be more efficient with our resources. Unfortunately, we would then lose the ability to support both modes.
So, our lives would be much easier if we could register both the jvp and vjp and compute them both under a `jax.pure_callback`.
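The workaround described above (full Jacobian from a callback, Jacobian product traced by JAX) might look roughly like the sketch below. This is not PennyLane code: `host_f` and `host_jacobian` are hypothetical stand-ins for an opaque evaluation and its Jacobian routine.

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_f(x):
    # Hypothetical opaque evaluation; runs on the host, outside JAX tracing.
    x = np.asarray(x)
    return np.array([np.sum(x**2), x[0] * x[1]], dtype=x.dtype)

def host_jacobian(x):
    # Hypothetical opaque routine returning the full 2x2 Jacobian of host_f.
    x = np.asarray(x)
    return np.array([[2 * x[0], 2 * x[1]], [x[1], x[0]]], dtype=x.dtype)

@jax.custom_jvp
def f(x):
    out_type = jax.ShapeDtypeStruct((2,), x.dtype)
    return jax.pure_callback(host_f, out_type, x)

@f.defjvp
def f_jvp(primals, tangents):
    (x,), (dx,) = primals, tangents
    jac_type = jax.ShapeDtypeStruct((2, 2), x.dtype)
    jac = jax.pure_callback(host_jacobian, jac_type, x)
    # jac @ dx is traced by JAX and is linear in dx, so JAX can
    # transpose it to obtain the reverse-mode rule as well.
    return f(x), jac @ dx

x = jnp.array([1.0, 2.0])

# Forward mode: J @ dx
_, tangent_out = jax.jvp(f, (x,), (jnp.array([1.0, 0.0]),))

# Reverse mode: dy^T @ J, derived from the same JVP rule
_, f_vjp = jax.vjp(f, x)
(cotangent_in,) = f_vjp(jnp.array([1.0, 0.0]))

print(tangent_out, cotangent_in)
```

The cost is the one noted above: the whole Jacobian is computed even when only a single product is needed.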
I would also like to have both `custom_jvp` and `custom_vjp`. I want to compute the Hessian matrix of a model, so it is good to have both rather than using `grad` twice.
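For context, a common way to get Hessian information in JAX composes forward mode over reverse mode instead of applying `grad` twice, which is why support for both modes matters here. A minimal sketch on a plain function with no custom rules (`loss` and `hvp` are illustrative names):

```python
import jax
import jax.numpy as jnp

def loss(x):
    # Toy scalar function; its Hessian is diag(6 * x).
    return jnp.sum(x**3)

def hvp(fun, x, v):
    # Forward-over-reverse: push the tangent v through the gradient function.
    return jax.jvp(jax.grad(fun), (x,), (v,))[1]

x = jnp.array([1.0, 2.0])
v = jnp.array([1.0, 1.0])

hv = hvp(loss, x, v)         # Hessian-vector product, no full Hessian formed
hess = jax.hessian(loss)(x)  # full Hessian, built as jacfwd(jacrev(loss))

print(hv)
print(hess)
```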
+1 for allowing both custom vjp and jvp.
In my application space, many of the important operations are related to large matrices that are never actually formed explicitly (because they are much too large to fit in memory). For example, if $u, v, x$ are column vectors with 1,000,000 components, the function
$$ f(x) \coloneqq u (v^\top x) = A x, \qquad A \coloneqq u \; v^\top $$
has simple implementations of its vjp ($dy^\top A = (dy^\top u) v^\top$) and jvp ($A \; dx = u (v^\top dx)$), but forming $A$ from the jvp definition and transposing it to implement the vjp is incredibly inefficient, if not impossible.
This example is trivial for the sake of demonstration, but the underlying "matrix-free" idea is important in Krylov methods and HPC applications.
edit: to clarify, I'm mostly interested in being able to specify both `vjp` and `jvp` to make "black box" functions (e.g. C++ libraries w/ python bindings) work more seamlessly in JAX. JAX's existing strategy of tracing the `jvp` and generating the `vjp` from it works great, provided that the implementation is traceable. But it's not always practical to reimplement large libraries in JAX to make them directly traceable.
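To illustrate the traceable case with the rank-one example above, here is a minimal sketch. The sizes and values are arbitrary stand-ins, and since `f` is traceable anyway the custom rule is purely illustrative. The matrix-free JVP is written in JAX, and the VJP is derived from it by transposition without ever forming $A$:

```python
import jax
import jax.numpy as jnp

n = 1000  # small stand-in; A = u v^T (n x n) is never materialized
u = jnp.linspace(1.0, 2.0, n)
v = jnp.linspace(-1.0, 1.0, n)

@jax.custom_jvp
def f(x):
    # f(x) = u (v^T x)
    return u * jnp.dot(v, x)

@f.defjvp
def f_jvp(primals, tangents):
    (x,), (dx,) = primals, tangents
    # Matrix-free JVP: A dx = u (v^T dx), linear in dx and traceable.
    return f(x), u * jnp.dot(v, dx)

x = jnp.ones(n)
dy = jnp.arange(n, dtype=x.dtype)

# Reverse mode: JAX transposes the traced JVP rule.
_, f_vjp = jax.vjp(f, x)
(vjp_out,) = f_vjp(dy)

# Analytic result for comparison: dy^T A = (dy^T u) v^T
expected = jnp.dot(dy, u) * v

print(jnp.allclose(vjp_out, expected, rtol=1e-4))
```

This only works because the JVP rule is ordinary JAX code. If it were hidden inside an opaque library call, there would be nothing for JAX to transpose, which is the gap described above.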
For those who are interested in this topic, the JAX team published a paper on this. It explains the underlying philosophy.
Believe it or not, @dfm is about to fix this issue!
@mattjj Do you have a link to the relevant issue/PR? And should this be closed?
@carlosgmartin — I think @mattjj might have overstated the "about to fix" part of his comment :D. We're still iterating on design questions, and realistically I don't think that a good solution will land before https://github.com/google/jax/pull/23299. There is a proof-of-concept in https://github.com/dfm/jax/tree/custom-ad-2, but I don't have a timeline for when/what will be merged. Always happy to help with specific use cases while we wait!
I was wondering: is it possible to define both a custom vjp and a custom jvp for a function?
Above fails with