google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

JEP: Allow finite difference method for non-jax differentiable functions #15425

Open mbmccoy opened 1 year ago

mbmccoy commented 1 year ago

Motivation: Make JAX relevant to many more users

JAX provides a powerful general-purpose tool for automatic differentiation, but it usually requires that users write code that is JAX-traceable end-to-end.

Many scientific and industrial applications involve large legacy codebases where the effort required to port the system to end-to-end JAX is prohibitive. In other cases, users are tied to proprietary software, or the underlying software is not written in Python, and they are likewise unable to readily convert the underlying code to JAX.

Without JAX, the standard method for performing optimization involves computing derivatives using finite-difference methods. While these can be integrated into JAX via custom derivative rules, the process is cumbersome, which significantly limits the set of users able to integrate JAX into their work.

This JEP proposes a simple method for computing numerical derivatives in JAX. I expect that this change would expand the potential user base of JAX substantially, and could drive adoption of JAX across both academia and industry.

Proposal: A decorator that computes JAX derivatives using finite differences

Let's start with an example.

import jax
from jax import numpy as jnp
from jax.experimental.finite_difference import jax_finite_difference  # proposed module
import numpy as old_np  # Not jax-traceable

@jax_finite_difference
def rosenbach2(x, y):
    """Compute the Rosenbach function for two variables."""
    return old_np.power(1 - x, 2) + 100 * old_np.power(y - old_np.power(x, 2), 2)

def rosenbach3(x, y, z):
    """Compute the Rosenbach function for three variables."""
    return rosenbach2(x, y) + rosenbach2(y, z)

value, grad = jax.value_and_grad(rosenbach3)(1., 2., 3.)

By wrapping the function rosenbach2 in jax_finite_difference, it becomes compatible with JAX's automatic differentiation tooling and works with other JAX transformations such as vmap.
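
For instance (a hypothetical usage sketch under the proposed API, reusing rosenbach3 and the imports above):

xs = jnp.linspace(-1.0, 1.0, 8)

# Batched gradients with respect to the first argument; the finite-difference
# rule is applied under both grad and vmap.
batched_grads = jax.vmap(jax.grad(rosenbach3, argnums=0), in_axes=(0, None, None))(xs, 2.0, 3.0)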

Additional options will be available for power users who may want to specify the step size, or forward vs center vs backward mode.

This is feasible.

I have working, tested code that does the above for any function that accepts and returns JAX Arrays. If there is interest in this JEP, I will happily make a PR.

Limits of this JEP

This proposal will not support XLA out of the box

The initial proposal does not aim to support XLA for finite differences. While it should be possible to overcome this limitation using a JAX Foreign Function Interface (FFI) [Issue #12632, PR], it would be best to wait until the FFI is finalized before implementing XLA for finite differences.

The downsides of float32 increase with finite differences

Using single-precision (32-bit) floating point numbers in finite differences may lead to unacceptably large errors in many cases. While this is not a foregone conclusion—many functions can be differentiated just fine with 32-bit floating point—we probably want to plan for mitigation strategies, e.g., warning users when finite differences are computed in 32-bit precision, or performing the difference computation internally in 64-bit precision.

The second strategy would likely be more important when using FFI in conjunction with XLA in later work. At this stage a warning may be all that's needed.
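
As a rough standalone illustration (not part of the proposal), a forward difference of np.exp at 1.0 with a step of 1e-6 is orders of magnitude less accurate in float32 than in float64:

import numpy as np

def forward_difference(f, x, h, dtype):
    x, h = dtype(x), dtype(h)
    return (f(x + h) - f(x)) / h

# The exact derivative of exp at 1.0 is e; compare the absolute error per dtype.
for dtype in (np.float32, np.float64):
    approx = forward_difference(np.exp, 1.0, 1e-6, dtype)
    print(dtype.__name__, abs(float(approx) - np.e))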

Related JEPs

The proposed Foreign Function Interface [Issue #12632, PR] will provide a way for JAX code to call out to external code in the course of derivative computation. However, it does not provide a method for computing derivatives—those must still be defined by the user.

That said, we expect that the FFI combined with our finite-difference method would enable "the dream": nearly arbitrary user code fully integrated with JAX via a single decorator.

froystig commented 1 year ago

Thanks for the proposal! We've discussed something along these lines amongst ourselves at various points. It's a good idea!

Do you have an idea for the API you'd propose? When this has come up before, our thought was to make it a convenience wrapper around jax.custom_jvp. That suggests an implementation using custom_jvp as well.
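
For concreteness, such a wrapper might look roughly like the following (an illustrative sketch only, not an official API: it assumes scalar arguments and a scalar output, uses a central-difference rule with a made-up eps default, and would still need to be combined with pure_callback, as discussed below, for functions that cannot accept traced inputs):

import jax
import jax.numpy as jnp

def finite_difference_custom_jvp(f, eps=1e-4):
    """Hypothetical convenience wrapper: attach a finite-difference JVP rule to f."""
    f_wrapped = jax.custom_jvp(f)

    def partial_derivative(i, primals):
        # Central difference in the i-th (scalar) argument; f only ever sees primal values.
        shifted = list(primals)
        shifted[i] = primals[i] + eps
        plus = f(*shifted)
        shifted[i] = primals[i] - eps
        minus = f(*shifted)
        return (plus - minus) / (2 * eps)

    @f_wrapped.defjvp
    def _jvp(primals, tangents):
        primal_out = f(*primals)
        # Linear in the tangents, so JAX can also transpose this rule for reverse mode.
        tangent_out = sum(
            partial_derivative(i, primals) * t for i, t in enumerate(tangents)
        )
        return jnp.asarray(primal_out), jnp.asarray(tangent_out)

    return f_wrapped

Forward and reverse mode both route through the same finite-difference partials here, since the tangent expression is linear in the tangents.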

mbmccoy commented 1 year ago

I actually implemented it using full-on custom primitives following this guide.

Since it seems like there's interest, I'll make the PR.

mattjj commented 1 year ago

@mbmccoy any reason not to use a custom_jvp or custom_vjp? Those are a lot simpler than a custom primitive, and I think it'd be the preferred approach unless I'm missing something.

froystig commented 1 year ago

In particular custom_jvp would support both forward- and reverse-mode AD.

ASEM000 commented 1 year ago

Hello,

For context on the API design: I have a small WIP library for finite differences in which I implemented a finite-difference version of jax.grad with a couple of options.

colab


import jax 
from jax import numpy as jnp
import numpy as old_np  # Not jax-traceable
import finitediffx as fdx
import functools as ft 
from jax.experimental import enable_x64

with enable_x64():

    @fdx.fgrad
    @fdx.fgrad
    def np_rosenbach2(x, y):
        """Compute the Rosenbach function for two variables."""
        return old_np.power(1-x, 2) + 100*old_np.power(y-old_np.power(x, 2), 2)

    @ft.partial(fdx.fgrad, derivative=2)
    def np2_rosenbach2(x, y):
        """Compute the Rosenbach function for two variables."""
        return old_np.power(1-x, 2) + 100*old_np.power(y-old_np.power(x, 2), 2)

    @jax.grad 
    @jax.grad
    def jnp_rosenbach2(x, y):
        """Compute the Rosenbach function for two variables."""
        return jnp.power(1-x, 2) + 100*jnp.power(y-jnp.power(x, 2), 2)

    print(np_rosenbach2(1., 2.))   # second derivative w.r.t. x via nested fdx.fgrad
    print(np2_rosenbach2(1., 2.))  # second derivative via the derivative=2 option
    print(jnp_rosenbach2(1., 2.))  # exact second derivative via nested jax.grad

402.0000951997936
402.0000951997936
402.0
ASEM000 commented 1 year ago

As far as I understand, we can also use fdx.fgrad with custom_jvp and pure_callback to make non-traceable code work with JAX transformations.


import functools as ft

import jax
import jax.numpy as jnp
import numpy as onp

import finitediffx as fdx

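# Wrap a plain-NumPy function in jax.pure_callback, inferring the output shape/dtype
# from the broadcasted inputs, so it can be called from JIT-compiled/transformed JAX code.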
def wrap_pure_callback(func):
    @ft.wraps(func)
    def wrapper(*args, **kwargs):
        args = [jnp.asarray(arg) for arg in args]
        func_ = lambda *args, **kwargs: func(*args, **kwargs).astype(args[0].dtype)
        result_shape_dtype = jax.ShapeDtypeStruct(
            shape=jnp.broadcast_shapes(*[arg.shape for arg in args]),
            dtype=args[0].dtype,
        )
        return jax.pure_callback(
            func_, result_shape_dtype, *args, **kwargs, vectorized=True
        )

    return wrapper

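# Give the callback a custom JVP whose tangent is assembled from finitediffx
# finite-difference partial derivatives, making it differentiable with JAX.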
def define_finitdiff_jvp(func):
    func = jax.custom_jvp(func)

    @func.defjvp
    def func_jvp(primals, tangents):
        primal_out = func(*primals)
        tangent_out = sum(
            fdx.fgrad(func, argnums=i)(*primals) * dot for i, dot in enumerate(tangents)
        )
        return jnp.array(primal_out), jnp.array(tangent_out)

    return func

@jax.jit
@define_finitdiff_jvp
@wrap_pure_callback
def np_rosenbach2(x, y):
    """Compute the Rosenbach function for two variables."""
    return onp.power(1 - x, 2) + 100 * onp.power(y - onp.power(x, 2), 2)

@jax.jit
def jnp_rosenbach2(x, y):
    """Compute the Rosenbach function for two variables."""
    return jnp.power(1 - x, 2) + 100 * jnp.power(y - jnp.power(x, 2), 2)

print(jax.value_and_grad(np_rosenbach2, argnums=0)(1.0, 2.0))
print(jax.value_and_grad(jnp_rosenbach2, argnums=0)(1.0, 2.0))

print(jax.value_and_grad(np_rosenbach2, argnums=1)(1.0, 2.0))
print(jax.value_and_grad(jnp_rosenbach2, argnums=1)(1.0, 2.0))

print(jax.vmap(jax.grad(np_rosenbach2), in_axes=(0, None))(jnp.array([1.0, 2.0, 3.0, 0.2]), 2.0))
print(jax.vmap(jax.grad(jnp_rosenbach2), in_axes=(0, None))(jnp.array([1.0, 2.0, 3.0, 0.2]), 2.0))

(Array(100., dtype=float32), Array(-399.9948, dtype=float32, weak_type=True))
(Array(100., dtype=float32, weak_type=True), Array(-400., dtype=float32, weak_type=True))
(Array(100., dtype=float32), Array(199.97772, dtype=float32, weak_type=True))
(Array(100., dtype=float32, weak_type=True), Array(200., dtype=float32, weak_type=True))
[-399.9948  1601.8411  8403.304   -158.45016]
[-400.      1602.      8404.      -158.40001]
mbmccoy commented 1 year ago

As far as I understand

So understated! Seems like you've written a package that does most of the work here. :)

Given this, is an implementation of finite differences in the core JAX package (e.g., under jax.experimental) still desired? I'm inclined to think it would be a very valuable addition to the JAX ecosystem, either here or in a sister project, given how useful I'm finding the tool when working with third-party scientific libraries.

FYI, I'm happy to collaborate on a PR, @ASEM000, if that's of interest. I'll try to get something up for comment in the next day or so; I've had some busy days this last week.

@mattjj I'll look into using custom_vjp and custom_jvp—I recall having a reason I didn't use those to start with, but it's possible that I just missed something.

mbmccoy commented 1 year ago

To be concrete about the API @froystig, here are some guiding principles that I have in mind:

  1. A reasonable name. Personally, I think a decorator named jaxify is appropriate because it gets at "what it means for most users", that is, make their existing code work with JAX. Other names like jax_finite_difference are also OK, but they are based on implementation and are a bit more technical.
  2. One-line decorator that "just works" for 80% of cases, and provides sensible error messages for the most common issues (e.g., not abiding by the API).
  3. More powerful options for ~15-19% of the remaining cases. For example:
    • Step size control
    • Mode choice (e.g., central, forward, and backward modes)
    • Allowing different step sizes or modes per argument.
  4. Reasonably efficient, so that, e.g., forward mode uses significantly fewer function evaluations than central mode (see the sketch after this list).
  5. Consistent with the existing API. I like the idea that it has the same requirements as pure_callback (as @ASEM000 helpfully suggested).
  6. It might be nice to warn users when the use of this decorator has kept them from having full hardware acceleration, though that may be best left as a feature for pure_callback. (Not sure if this is easy or hard.)
  7. Well-documented, e.g., a notebook with documentation for use.
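
To make the evaluation-count point in item 4 concrete, a rough standalone sketch (not the proposed implementation): a forward-difference gradient of an n-argument scalar function reuses f(x) and needs n + 1 evaluations in total, while central differences need 2n.

import numpy as np

def fd_gradient(f, args, h=1e-6, mode="forward"):
    args = np.asarray(args, dtype=float)
    grad = np.empty_like(args)
    f0 = f(*args) if mode == "forward" else None  # baseline shared by all arguments
    for i in range(args.size):
        step = np.zeros_like(args)
        step[i] = h
        if mode == "forward":
            grad[i] = (f(*(args + step)) - f0) / h  # one extra call per argument
        else:  # central
            grad[i] = (f(*(args + step)) - f(*(args - step))) / (2 * h)  # two calls per argument
    return grad

print(fd_gradient(lambda x, y: (1 - x)**2 + 100 * (y - x**2)**2, [1.0, 2.0]))  # ≈ [-400., 200.]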

API Examples

The "80%" use case

Most use cases should start with a simple decorator:

import jax
from jax.experimental.jaxify import jaxify
from jax.experimental import enable_x64

@jaxify
def my_func(array1, array2):
    return some_complex_math(array1, array2)

print(jax.value_and_grad(my_func)(x, y))  # Warn about not using 64-bit math

with enable_x64():
    jax.value_and_grad(my_func)(x, y)  # No warning

Power use cases


# The user wants control over the step size and differencing mode
@jaxify(step_size=1e-9, mode="forward")
def my_func(array1, array2):
    return some_complex_math(array1, array2)

# The user wants per-argument control over the step size and mode
@jaxify(step_size=(1e-9, 1e-3), mode=("forward", "center"))
def my_func(array1, array2):
    return some_complex_math(array1, array2)
froystig commented 1 year ago

@mbmccoy – Why bundle together (a) setting up a derivative rule based on FD with (b) setting up a pure_callback call?

mbmccoy commented 1 year ago

There are two reasons I can think of to support FD for the same class of functions that pure_callback supports: equivalence and consistency. (Note that I'm suggesting that we use the same API for the functions, nothing more.)

Equivalence: The class of functions theoretically supportable by a generic FD technique is precisely the class theoretically supportable by the pure_callback mechanism (that is, pure functions that accept and return numpy arrays).

Think about it: almost by definition, the functions we'd want to finite-difference are not supported directly within JAX's JIT, so computing their values during a generic finite-difference routine requires the use of some mechanism like pure_callback. If we manage to write an FD routine that supports more functions than pure_callback, we could backport the functionality to pure_callback using, for example, the value part of jax.value_and_grad.

Conversely, it's pretty easy to see that—at least in principle—we can write a wrapper that will apply finite differences to pure functions that accept and return numpy arrays. That's, of course, the challenge I've set out for myself here.

Consistency: Given the close link between the sets of functions supportable by pure_callback and those supportable by a generic finite difference routine, we should probably just insist that they are the same and test for it. This will allow us to leverage pure_callback and keep the overall API complexity roughly constant.

Example documentation

"""Make a function differentiable with JAX using finite differences.

[...]

The function must be pure (that is, side-effect free and deterministic given its inputs), 
and both its inputs and outputs must be ``numpy`` arrays. These requirements are the
same as for the ``jax.pure_callback`` function.

[...]
"""

Note: edited for clarity and added an example docstring.

f0uriest commented 1 year ago

I'd also be very interested in helping with this. I've done a similar thing using pure_callback and custom_jvp but I haven't been able to get it to support both forward and reverse mode.

If I define a custom_jvp for the pure_callback, forward mode stuff like jacfwd works fine. However, if I try to use jacrev I usually get an error JaxStackTraceBeforeTransformation: ValueError: Pure callbacks do not support transpose. Please use `jax.custom_vjp` to use callbacks while taking gradients.

Using custom_vjp allows reverse mode to work, but since a function can't have both custom_jvp and custom_vjp defined (#3996), we're limited to one or the other.
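
For reference, the reverse-mode-only alternative might look roughly like this (a sketch only, reusing the fdx.fgrad helper from earlier in the thread; a function wrapped this way supports grad/jacrev but can no longer be forward-differentiated):

import jax
import finitediffx as fdx

def define_finitediff_vjp(func):
    wrapped = jax.custom_vjp(func)

    def fwd(*args):
        # Save the primal inputs as residuals for the backward pass.
        return func(*args), args

    def bwd(args, cotangent):
        # One finite-difference partial derivative per argument, scaled by the cotangent.
        return tuple(
            fdx.fgrad(func, argnums=i)(*args) * cotangent for i in range(len(args))
        )

    wrapped.defvjp(fwd, bwd)
    return wrapped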

ASEM000 commented 1 year ago

@mbmccoy I added new functionality, define_fdjvp, here, which satisfies most of the specs you defined. One note: forward, central, and backward modes are replaced by offsets, so for $\sum_i A_i \, f(x + a_i) / \text{stepsize}$ you only provide offsets=jnp.array([ai, ... ]) and the coefficients $A_i$ are calculated based on the other configs.
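
(For intuition only, and not finitediffx's actual implementation: given offsets measured in units of the step size, the coefficients $A_i$ can be obtained by matching Taylor expansions, e.g.:)

import numpy as np
from math import factorial

def stencil_coefficients(offsets, derivative=1):
    """Solve for A_i such that sum_i A_i * f(x + a_i*h) / h**derivative approximates f^(derivative)(x)."""
    offsets = np.asarray(offsets, dtype=float)
    n = offsets.size
    # Taylor matching: sum_i A_i * a_i**k / k! must equal 1 for k == derivative, else 0.
    vandermonde = np.array([[a**k / factorial(k) for a in offsets] for k in range(n)])
    rhs = np.zeros(n)
    rhs[derivative] = 1.0
    return np.linalg.solve(vandermonde, rhs)

print(stencil_coefficients([-1.0, 0.0, 1.0], derivative=1))  # central difference: [-0.5, 0., 0.5]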