jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Defining both custom_vjp and custom_jvp #3996

Open thisiscam opened 4 years ago

thisiscam commented 4 years ago

I was wondering, is it possible to define both a custom VJP and a custom JVP for a function?

from jax import custom_vjp, custom_jvp, jacfwd, jacrev

def f(x):
  return x

# First register a custom JVP for f.
f = custom_jvp(f)
f.defjvp(lambda primals, tangents: (primals[0], tangents[0]))

# Then try to also wrap the result in a custom VJP.
f = custom_vjp(f)

def fwd(x):
  return x, ()  # (output, residuals)

def bwd(res, ct):
  return (ct,)

f.defvjp(fwd, bwd)
print(jacfwd(f)(1.))
print(jacrev(f)(1.))

The above fails with:

TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.

mattjj commented 4 years ago

Thanks for the question!

No, that's not currently possible. It's documented in the custom derivatives tutorial (see the subsection "Forward-mode autodiff cannot be used on the jax.custom_vjp function and will raise an error" for more).

Can you say more about why you might need to define both?

thisiscam commented 4 years ago

Aha, didn't see that sentence in the tutorial.

My use case is quite intricate and very particular. I was trying to fix https://github.com/tensorflow/probability/issues/1045.

The idea is that oryx defines this "tagging" transformation; what I want is for that tag to be preserved across jacfwd/jacrev/grad transformations. So I wanted to override both jacfwd and jacrev for the tagged value (actually, for a function that returns those tagged values); the overridden behavior would just tag the gradient values as well. It's a bit hacky IMO, but it's at least an attempt at this.
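
For illustration, here is a hedged sketch of the sow/reap tagging pattern being discussed, assuming oryx's documented oryx.core.sow/reap API (the tag and name keywords); it is not the code from the linked issue.

from oryx.core import sow, reap

def model(x):
  h = sow(x ** 2, tag='intermediate', name='h')  # tag an intermediate value
  return h + 1.

# reap collects the tagged value from a plain forward evaluation.
print(reap(model, tag='intermediate')(2.))  # {'h': 4.0}

# The goal described above is for values produced by grad/jacfwd/jacrev of
# `model` to carry corresponding tags too, hence the attempt to override both
# the custom jvp and vjp of the tagged function.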

jekbradbury commented 4 years ago

There are a few ways you could solve this, but one would be to create your own version of Oryx's sow primitive that has different behavior under JVP and transpose transformations (another would be to expose more control over how sow itself behaves, along the lines of Oryx's existing support for controlling how sow behaves under a scan). CC @sharadmv

(Separately, the most straightforward answer to "how do I have a custom JVP and VJP at the same time" is to create a full-blown Primitive.)
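
To illustrate that route, here is a minimal sketch of a hand-defined Primitive with both a JVP rule and a transpose rule (the transpose is what reverse mode uses). It relies on jax.core and jax.interpreters.ad internals, whose locations may change across JAX versions, and the op itself (a "double" primitive) is made up for illustration.

import jax
import jax.numpy as jnp
from jax import core
from jax.interpreters import ad

double_p = core.Primitive("double")

def double(x):
  return double_p.bind(x)

# Concrete evaluation and shape/dtype rule.
double_p.def_impl(lambda x: 2.0 * x)
double_p.def_abstract_eval(lambda x: core.ShapedArray(x.shape, x.dtype))

def double_jvp(primals, tangents):
  (x,), (t,) = primals, tangents
  # The tangent rule reuses the primitive itself, so reverse mode goes through
  # the transpose rule below rather than differentiating an implementation.
  return double(x), double(t)

ad.primitive_jvps[double_p] = double_jvp

def double_transpose(ct, x):
  # Transpose of the linear map t -> 2 * t; one cotangent per input.
  return (double(ct),)

ad.primitive_transposes[double_p] = double_transpose

x = jnp.array(3.0)
print(jax.jvp(double, (x,), (jnp.ones_like(x),)))  # forward mode: (6.0, 2.0)
print(jax.grad(double)(x))                         # reverse mode: 2.0

(A real primitive would also need batching and lowering rules to work under vmap and jit.)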

sharadmv commented 4 years ago

What are your thoughts on, rather than tying harvest to sow exclusively, having it look inside a set of primitives that all have name/tag/mode params? That would let users write their own primitives with whatever specialized semantics they want and then add them to the set.

Also, I realize now that this discussion is a better fit for the Oryx thread.

albi3ro commented 1 year ago

I'd like to request the ability to define both custom_vjp and custom_jvp.

I work on PennyLane, where we register derivatives computed on quantum computers with JAX.

Currently, we compute the full Jacobian under a pure_callback, then let JAX trace the computation of the Jacobian product. This allows us to provide both the jvp and vjp.
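
A hedged sketch of that pattern, with made-up black_box_fn / black_box_jac placeholders standing in for the device calls (the real PennyLane code is more involved): because the Jacobian-vector product is ordinary traced JAX code, JAX can also transpose it, so both jacfwd and jacrev work.

import jax
import jax.numpy as jnp
import numpy as np

def black_box_fn(x):           # stand-in for the device execution
  return np.tanh(np.asarray(x))

def black_box_jac(x):          # stand-in for the device-computed Jacobian
  x = np.asarray(x)
  return np.diag(1.0 - np.tanh(x) ** 2)

def _call_black_box(x):
  return jax.pure_callback(black_box_fn, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

@jax.custom_jvp
def f(x):
  return _call_black_box(x)

@f.defjvp
def f_jvp(primals, tangents):
  (x,), (t,) = primals, tangents
  jac = jax.pure_callback(
      black_box_jac, jax.ShapeDtypeStruct(x.shape + x.shape, x.dtype), x)
  return _call_black_box(x), jac @ t  # full Jacobian, traced matvec

x = jnp.arange(3.0)
print(jax.jacfwd(f)(x))
print(jax.jacrev(f)(x))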

If we were to use a pure_callback around either the jvp or the vjp instead, we could be more efficient with our resources. Unfortunately, we would then lose the ability to support both modes.

So, our lives would be much easier if we could register both the jvp and vjp and compute them both under a jax.pure_callback.

ToshiyukiBandai commented 1 year ago

I would also like to have both custom_vjp and custom_jvp. I want to compute the Hessian matrix of a model, so it would be good to have both rather than applying grad twice.
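
For context, a minimal sketch of why this bites (toy function, not the actual model): jax.hessian is forward-over-reverse (jacfwd of jacrev), so any custom_vjp inside the model trips the same error.

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
  return jnp.sum(jnp.sin(x))

def f_fwd(x):
  return f(x), x          # (output, residuals)

def f_bwd(x, ct):
  return (ct * jnp.cos(x),)

f.defvjp(f_fwd, f_bwd)

x = jnp.arange(3.0)
jax.grad(f)(x)      # fine: reverse mode uses the custom vjp
jax.hessian(f)(x)   # TypeError: can't apply forward-mode autodiff (jvp)
                    # to a custom_vjp function.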

samuelpmish commented 1 year ago

+1 for allowing both custom vjp and jvp.


In my application space, many of the important operations are related to large matrices that are never actually formed explicitly (because they are much too large to fit in memory). For example, if $u, v, x$ are column vectors with 1,000,000 components, the function

$$ f(x) \coloneqq u (v^\top x) = A x, \qquad A \coloneqq u\, v^\top $$

has simple implementations of its vjp ($dy^\top A = (dy^\top u)\, v^\top$) and jvp ($A\, dx = u (v^\top dx)$), but forming $A$ from the jvp definition and transposing it to implement the vjp is incredibly inefficient, if not impossible.

This example is trivial for the sake of demonstration, but the underlying "matrix-free" idea is important in Krylov methods and HPC applications.
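
As a sketch, the two products above written matrix-free (hypothetical helper names; nothing here is registered with JAX's autodiff):

import jax.numpy as jnp

def rank1_jvp(u, v, dx):
  # A @ dx = u (v^T dx): two length-n operations instead of forming the n x n matrix.
  return u * jnp.vdot(v, dx)

def rank1_vjp(u, v, dy):
  # dy^T A = (dy^T u) v^T
  return v * jnp.vdot(dy, u)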


edit: to clarify, I'm mostly interested in being able to specify both the vjp and jvp to make "black box" functions (e.g. C++ libraries with Python bindings) work more seamlessly in JAX. JAX's existing strategy of tracing the jvp and generating the vjp from it works great, provided that the implementation is traceable. But it's not always practical to reimplement large libraries in JAX to make them directly traceable.

ToshiyukiBandai commented 11 months ago

For those interested in this topic, the JAX team has published a paper that explains the underlying philosophy:

https://dl.acm.org/doi/abs/10.1145/3571236

mattjj commented 2 months ago

Believe it or not, @dfm is about to fix this issue!

carlosgmartin commented 5 days ago

@mattjj Do you have a link to the relevant issue/PR? And should this be closed?

dfm commented 5 days ago

@carlosgmartin — I think @mattjj might have overstated the "about to fix" part of his comment :D. We're still iterating on design questions, and realistically I don't think that a good solution will land before https://github.com/google/jax/pull/23299. There is a proof-of-concept in https://github.com/dfm/jax/tree/custom-ad-2, but I don't have a timeline for when/what will be merged. Always happy to help with specific use cases while we wait!