Justin-Tan opened this issue 9 months ago
Okay, many things to respond to here!
Speed
With respect to the speed, for your JAX code are you using block_until_ready? In practice this means writing things out something like:
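import timeit

import jax
import optimistix as optx
from jax import vmap

# fn, solver, y and b are as defined in your MWE.
@jax.jit
def run(y, b):
    sol = optx.root_find(vmap(fn), solver, y, b)
    return sol.value

jax.block_until_ready(run(y, b))  # first call compiles; keep it out of the timings
times = timeit.repeat(lambda: jax.block_until_ready(run(y, b)), number=1, repeat=10)
print(min(times))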
Recalculating Jacobians
You commented on calculating the Jacobian afresh every iteration. If you are using the typical Newton algorithm, then this is expected (and desired) behaviour. But if you'd prefer to use a quasi-Newton algorithm like the chord method (which computes the Jacobian once at the initial point and then re-uses it), then there is optx.Chord as well.
Analytical Jacobians
You commented on supplying an analytical Jacobian. This isn't necessary, as the analytical Jacobian is already derived from fn automatically using autodifferentiation. Unless the autodiff does something surprisingly inefficient, providing one manually wouldn't meaningfully improve things.
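To illustrate, here is a small check, using a hypothetical residual fn(y, args) = y**3 - args, that JAX's forward-mode autodiff (jax.jacfwd) reproduces the hand-derived Jacobian diag(3 * y**2) exactly. Which autodiff mode Optimistix uses internally is not stated here; jax.jacfwd is just one convenient way to see that autodiff gives the exact derivative rather than a finite-difference approximation.

```python
import jax
import jax.numpy as jnp


def fn(y, args):
    # Hypothetical elementwise residual used for illustration.
    return y**3 - args


def analytic_jac(y, args):
    # Hand-derived Jacobian of fn with respect to y: a diagonal matrix with entries 3 * y**2.
    return jnp.diag(3 * y**2)


y = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([1.0, 8.0, 27.0])

autodiff_jac = jax.jacfwd(fn)(y, b)  # Jacobian with respect to the first argument, y
print(autodiff_jac)
print(jnp.allclose(autodiff_jac, analytic_jac(y, b)))  # True
```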
Custom Jacobians
If, despite everything, you really do want to provide a custom Jacobian, then this can be done using jax.custom_jvp. By wrapping your fn in a jax.custom_jvp, you can override how JAX calculates the derivatives of your code. (This will then be picked up by the autodiff that Optimistix uses to calculate the Jacobian.)
Does the above help?
Hi devs, this looks like a really nice library. I've been looking for a JAX-native root-finding method that supports vmap for some time. Currently I am using an external call to scipy.optimize.root together with the multiprocessing library, which is quite slow. However, the runtime for root finding using the Newton method in this library is slower than that approach; I suspect this is because the Jacobian needs to be calculated at each iteration. Is there a way for the user to supply an analytic Jacobian? Or could you point me in the right direction to implement this feature?
For reference, this is my MWE in case I am not doing things efficiently: