google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Gradient interoperability between JAX and other autodiff systems (TensorFlow, PyTorch, etc) #2154

Closed: shoyer closed this issue 1 month ago

shoyer commented 4 years ago

There's a nice duality between custom_gradient and jax.vjp/tf.gradients that is exactly what's needed for piping gradients back and forth, e.g.,

# untested!
import jax
import tensorflow as tf

def wrap_tf_in_jax(f):
    # Expose a TF function to JAX; the backward pass delegates to tf.gradients.
    @jax.custom_gradient
    def g(x):
        y = f(x)
        return y, lambda dy: tf.gradients(y, x, dy)[0]
    return g

def wrap_jax_in_tf(f):
    # Expose a JAX function to TF; the backward pass delegates to jax.vjp.
    @tf.custom_gradient
    def g(x):
        y, vjp_fun = jax.vjp(f, x)
        return y, lambda dy: vjp_fun(dy)[0]
    return g

It would be nice to expose user-facing functions for this.

Note that defining custom gradient rules isn't quite enough on its own. You also need to define a new JAX "primitive" in order to wrap non-traceable computations.
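
For reference, here is a minimal sketch (my own illustration, not from this thread) of one modern way to handle the non-traceable part: instead of defining a full primitive, wrap the opaque forward and backward calls with jax.pure_callback so they can appear in traced code, and attach the gradient with jax.custom_vjp. The opaque_fwd/opaque_bwd functions below are stand-ins for an external computation.

# Sketch only: opaque_fwd/opaque_bwd stand in for non-JAX computations.
import numpy as np
import jax

def opaque_fwd(x):
    # Stand-in for an external, non-traceable computation (TF, PyTorch, ...).
    return np.asarray(x) ** 2

def opaque_bwd(x, dy):
    # Hand-written backward pass for the external computation.
    return np.asarray(2 * x * dy, dtype=x.dtype)

@jax.custom_vjp
def f(x):
    return jax.pure_callback(opaque_fwd, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

def f_fwd(x):
    return f(x), x  # save the primal input as the residual

def f_bwd(x, dy):
    dx = jax.pure_callback(opaque_bwd, jax.ShapeDtypeStruct(x.shape, x.dtype), x, dy)
    return (dx,)

f.defvjp(f_fwd, f_bwd)

print(jax.grad(f)(3.0))  # 6.0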

lukasheinrich commented 4 years ago

The same trick of using "custom ops" to wrap entire differentiable programs also works for, e.g., wrapping TF into a PyTorch op. I agree it would be nice to add a clean API for that.
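
For illustration, an untested sketch (my own, not from this thread) of that direction: a torch.autograd.Function whose forward and backward run a TF computation under tf.GradientTape, with tf.square standing in for an arbitrary differentiable TF program.

# Sketch only: tf.square stands in for any differentiable TF computation.
import tensorflow as tf
import torch

class TFOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Run the TF computation eagerly, recording it on a tape for backward.
        x_tf = tf.constant(x.detach().cpu().numpy())
        with tf.GradientTape() as tape:
            tape.watch(x_tf)
            y_tf = tf.square(x_tf)
        ctx.tape, ctx.x_tf, ctx.y_tf = tape, x_tf, y_tf
        return torch.from_numpy(y_tf.numpy())

    @staticmethod
    def backward(ctx, dy):
        # Pipe the incoming PyTorch cotangent through TF's backward pass.
        dy_tf = tf.constant(dy.detach().cpu().numpy())
        dx_tf = ctx.tape.gradient(ctx.y_tf, ctx.x_tf, output_gradients=dy_tf)
        return torch.from_numpy(dx_tf.numpy())

# Usage: y = TFOp.apply(torch.tensor(3.0, requires_grad=True)); y.backward()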

shoyer commented 4 years ago

Here's a working version of a prototype for passing gradients between TensorFlow 2 and JAX, which may be a useful point of reference: https://gist.github.com/shoyer/5f72853c2788e99e785f4737ee8a6ae1

mattjj commented 1 month ago

Thanks for the suggestion, but I think we'll leave this to user-defined wrappers rather than adding it to the JAX public API.