patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0

Including user-defined Jacobian #17

Open Justin-Tan opened 9 months ago

Justin-Tan commented 9 months ago

Hi devs, this looks like a really nice library. I've been looking for a JAX-native root-finding method that supports vmap for some time. Currently I am using an external call to scipy.optimize.root together with the multiprocessing library, which is quite slow.

However, root finding with the Newton method in this library is slower than the above approach; I suspect this is because the Jacobian needs to be calculated at each iteration. Is there a way for the user to supply an analytic Jacobian? Or could you point me in the right direction to implement this feature?

For reference, this is my MWE in case I am not doing things efficiently:

from jax import random, vmap
import optimistix as optx

def fn(y, b):
    return (y - b) ** 2

M = 1024
key = random.PRNGKey(42)
key, key_ = random.split(key, 2)

y = random.normal(key, (M,))
b = random.normal(key_, (M,))

# `solver` was undefined in the original snippet; the tolerances here are illustrative.
solver = optx.Newton(rtol=1e-8, atol=1e-8)
sol = optx.root_find(vmap(fn), solver, y, b)
patrick-kidger commented 9 months ago

Okay, many things to respond to here!

Speed

With respect to the speed, for your JAX code are you:

- wrapping the whole computation in a single jax.jit?
- calling the compiled function once before timing, so that compilation time is excluded?
- using jax.block_until_ready when timing, since JAX dispatches work asynchronously?

In practice this means writing things out something like:

import timeit

import jax
from jax import vmap

@jax.jit
def run(y, b):
    # Reusing fn and solver from your snippet above.
    sol = optx.root_find(vmap(fn), solver, y, b)
    return sol.value

run(y, b)  # call once first, so compilation happens outside the timed region
times = timeit.repeat(lambda: jax.block_until_ready(run(y, b)), number=1, repeat=10)
print(min(times))

Recalculating Jacobians

You commented on calculating the Jacobian afresh every iteration. If using the typical Newton algorithm then this is expected (desired) behaviour. But if you're saying that you'd prefer to use a quasi-Newton algorithm like the chord method (that computes the Jacobian once at the initial point and then re-uses it), then there is optx.Chord as well.
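
For example, a minimal sketch of swapping in the chord method (the tolerances are illustrative, and fn, y, and b are from your MWE above):

import optimistix as optx
from jax import vmap

# Chord computes the Jacobian once at the initial point and then reuses it
# on every subsequent iteration.
solver = optx.Chord(rtol=1e-8, atol=1e-8)
sol = optx.root_find(vmap(fn), solver, y, b)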

Analytical Jacobians

You commented on supplying an analytical Jacobian. This isn't necessary: the analytical Jacobian is already derived from fn automatically using autodifferentiation. Unless the autodiff does something surprisingly inefficient, providing one manually wouldn't meaningfully improve things here.
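
To illustrate, a small sketch (the hand-derived Jacobian here is specific to the fn from your MWE above):

import jax
import jax.numpy as jnp

def fn(y, b):
    return (y - b) ** 2

y0 = jnp.array([1.0, 2.0])
b0 = jnp.array([0.5, 0.5])

# The Jacobian produced by autodiff is exact, not a finite-difference
# approximation, so it matches the hand-derived one.
auto_jac = jax.jacfwd(fn)(y0, b0)
analytic_jac = jnp.diag(2 * (y0 - b0))  # d/dy_i (y_i - b_i)^2 = 2 (y_i - b_i)
assert jnp.allclose(auto_jac, analytic_jac)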

Custom Jacobians

If despite everything you really do want to provide a custom Jacobian, then this can be done using jax.custom_jvp. By wrapping your fn in a jax.custom_jvp, you can override how JAX calculates derivatives of your code. (And this will then be picked up by the autodiff used by Optimistix to calculate the Jacobian.)
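
A minimal sketch of what that looks like (the tangent rule below is hand-derived for the fn in your MWE above; substitute your own analytic Jacobian-vector product):

import jax

@jax.custom_jvp
def fn(y, b):
    return (y - b) ** 2

@fn.defjvp
def fn_jvp(primals, tangents):
    y, b = primals
    y_dot, b_dot = tangents
    primal_out = fn(y, b)
    # Hand-derived JVP: d/dy (y - b)^2 = 2 (y - b), d/db (y - b)^2 = -2 (y - b).
    tangent_out = 2 * (y - b) * (y_dot - b_dot)
    return primal_out, tangent_out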

Does the above help?