patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0
265 stars, 12 forks

grad of vmap of function which wraps an optax solver occasionally fails #39

Closed: TomClarkMassSpec closed this issue 5 months ago

TomClarkMassSpec commented 5 months ago

Hi, I previously used the optx Newton root-finding algorithm, with a jnp.where to set a default value when the root finder couldn't find a solution. That correctly inserted the default value, but the program then failed to compute the gradient whenever the default value was used. I ended up switching to the optx minimiser wrapper around an Optax solver (optx.OptaxMinimiser), minimising a function in place of the root-finding operation, and this works very nicely because it handles the more extreme slopes that occur in my functions. But then parameters elsewhere were changed such that two long numbers equalled the negative of each other at x64 precision. The point is not that I need to buy a lottery ticket, but that I need a way to make grad work when the solver cannot find a solution.

Specifics: I use vmap to fill out the elements of a 1D array, calling a function once per element of the array. That function includes the following code:

New code (returns a value equal to y0 when no solution exists, but grad crashes):

import jax.numpy as jnp
import optax
import optimistix as optx
optimizer_acc = optx.OptaxMinimiser(optax.adabelief(learning_rate=1e-2), rtol=1e-8, atol=1e-8)
y0 = jnp.array(0.01)
sol = optx.minimise(fn=time_root_from_distance, solver=optimizer_acc, y0=y0, args=instance_of_acceleration, options=dict(lower=0.), max_steps=10000, throw=False)

Old code (returns the default value, but crashes when grad is requested):

import jax.numpy as jnp
import optimistix as optx
solver_root = optx.Newton(rtol=1e-5, atol=1e-4)
y0 = jnp.array(0.01)
sol = optx.root_find(fn=time_root_from_distance, solver=solver_root, y0=y0, args=instance_of_acceleration, options=dict(lower=0.), max_steps=10000, throw=False)
# Substitute the default 9999. wherever the root find did not succeed
Thv = jnp.where(sol.result == optx.RESULTS.successful, sol.value, 9999.)

Both work when they can find a solution. But when a solution does not exist or cannot be found, jax.grad(Objective) fails.

I'm not sure whether this question is misplaced, but any suggestions on how to return a gradient, not just a value, when no solution exists would be appreciated.

Thanks, Tom

patrick-kidger commented 5 months ago

This looks like a variety of https://docs.kidger.site/equinox/api/debug/#common-sources-of-nans, where doing something like jnp.where(if_success, might_fail(...), alternative_if_fails(...)) won't have an appropriate jnp.where on the backward pass.
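To make the backward-pass problem concrete, here is a minimal JAX-only sketch (an illustration, not code from the thread). jnp.sqrt stands in for the call that might fail (in the issue that is sol.value from optx.root_find), and the names naive and double_where are hypothetical. The forward pass never selects the NaN branch, but on the backward pass that branch still receives a zero cotangent, and 0 * nan is nan. The "double where" pattern described in the linked Equinox docs avoids this by feeding the risky branch a safe input:

```python
import jax
import jax.numpy as jnp


def naive(x):
    # sqrt(x) is NaN for x < 0; where() discards it on the forward pass only.
    return jnp.where(x > 0, jnp.sqrt(x), 0.0)


def double_where(x):
    ok = x > 0
    # Swap in a safe input (1.0) before the risky call, so every intermediate
    # stays finite and the zero cotangent for the unselected branch stays zero.
    safe_x = jnp.where(ok, x, 1.0)
    return jnp.where(ok, jnp.sqrt(safe_x), 0.0)


# Backward pass of naive: the unselected sqrt branch gets cotangent 0,
# then 0 * (0.5 / sqrt(-1.0)) = 0 * nan = nan.
print(jax.grad(naive)(-1.0))         # nan
print(jax.grad(double_where)(-1.0))  # 0.0
```

For an iterative solver there may be no obvious "safe input" to substitute, which is where wrapping the solve in a custom_vjp becomes the simpler route.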

You can probably fix this by wrapping the whole thing in a jax.custom_vjp (or an equinox.filter_custom_vjp if you prefer its syntax) so that the backward pass also has an appropriate jnp.where on it.
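A minimal sketch of that suggestion, under stated assumptions: a fixed-iteration Newton loop in plain JAX stands in for optx.root_find (so the example runs without Optimistix and is not its implementation); residual, solve, safe_root and the toy equation y**2 - theta = 0 (no real root for theta < 0) are hypothetical. The backward pass uses the implicit function theorem, dy/dtheta = -(df/dy)^(-1) df/dtheta, wrapped in jnp.where so a failed solve contributes a zero gradient instead of NaN:

```python
import jax
import jax.numpy as jnp

DEFAULT = 9999.0  # fallback value, as in the issue's jnp.where


def residual(y, theta):
    # Hypothetical equation f(y, theta) = y**2 - theta = 0; no real root if theta < 0.
    return y**2 - theta


def solve(theta, steps=50, tol=1e-5):
    # Plain Newton iteration; a stand-in for optx.root_find.
    def body(_, y):
        return y - residual(y, theta) / jax.grad(residual)(y, theta)

    y = jax.lax.fori_loop(0, steps, body, jnp.ones_like(theta))
    success = jnp.abs(residual(y, theta)) < tol  # False if y is NaN or inf
    return y, success


@jax.custom_vjp
def safe_root(theta):
    y, success = solve(theta)
    return jnp.where(success, y, DEFAULT)


def safe_root_fwd(theta):
    y, success = solve(theta)
    return jnp.where(success, y, DEFAULT), (y, success, theta)


def safe_root_bwd(res, ct):
    y, success, theta = res
    # Implicit function theorem: dy/dtheta = -(df/dy)^-1 * df/dtheta.
    dfdy = jax.grad(residual, argnums=0)(y, theta)
    dfdtheta = jax.grad(residual, argnums=1)(y, theta)
    # The bwd rule is not itself differentiated here, so a single where suffices:
    # NaNs from a failed solve are discarded and the fallback gets zero gradient.
    return (jnp.where(success, -ct * dfdtheta / dfdy, 0.0),)


safe_root.defvjp(safe_root_fwd, safe_root_bwd)

# grad of vmap, as in the issue title: one solvable and one unsolvable element.
thetas = jnp.array([4.0, -1.0])
grads = jax.grad(lambda t: jax.vmap(safe_root)(t).sum())(thetas)
```

Returning a zero gradient for the fallback is a design choice: the default 9999. is a constant, so its derivative really is zero. Whether that is useful to the outer optimisation is a separate modelling question.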

TomClarkMassSpec commented 5 months ago

Agreed. I abandoned this approach and coded from scratch.