aseyboldt opened 1 month ago

Very nice package :-)

While playing around with different optimization objectives I ran into an autodiff issue. The following always returns exactly zeros, which I think isn't correct. This might be related to the bisection search (I think that's what's used here?), but if the `while_loop` in there is a problem, I would have expected an error, not an incorrect result.
Interesting, I would have initially assumed it would have errored too. Here's an example that gets to the root of the problem (in both senses of the word):
```python
import jax.numpy as jnp
from jax import lax
import jax


def _bisection_search(func, *, lower, upper, tol: float, max_iter: int):
    def cond_fn(state):
        lower, upper, iterations = state
        # Keep looping while the bracket is wider than 2 * tol and we still
        # have iterations left.
        return jnp.logical_and((upper - lower) > 2 * tol, iterations < max_iter)

    def body_fn(state):
        lower, upper, iterations = state
        midpoint = (lower + upper) / 2
        # Keep whichever half of the bracket still contains the sign change
        # (assumes func is negative at lower and positive at upper).
        sign = jnp.sign(func(midpoint))
        lower = jnp.where(sign == 1, lower, midpoint)
        upper = jnp.where(sign == 1, midpoint, upper)
        return lower, upper, iterations + 1

    init_state = (lower, upper, 0)
    lower, upper, iterations = lax.while_loop(cond_fn, body_fn, init_state)
    root = (lower + upper) / 2
    return root, iterations


def get_root(x):
    # Root of arr + x on [-10, 10], i.e. approximately -x.
    return _bisection_search(
        func=lambda arr: arr + x,
        lower=-10,
        upper=10,
        tol=1e-7,
        max_iter=100,
    )[0]
```
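For example, checking the gradient directly along these lines should show it coming back as exactly zero, with no error, even though the exact root is `-x` and the true derivative is `-1`:

```python
# Illustrative check: the exact root of arr + x is -x, so d(root)/dx should
# be -1, yet autodiff through the bisection loop reports 0.0 and does not raise.
print(get_root(0.3))            # approximately -0.3
print(jax.grad(get_root)(0.3))  # 0.0
```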
Note that the gradient is actually zero everywhere it is defined, because the computed result is a step function of x. We can see this visually by plotting it over a very small region:
```python
import matplotlib.pyplot as plt

x = jnp.linspace(-1e-6, 1e-6, 1000)
roots = jax.vmap(get_root)(x)

plt.plot(x, roots)
```
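A quick numeric check along the same lines (using the arrays from the snippet above) should confirm the step structure:

```python
# The roots over this tiny interval take only a handful of distinct values,
# i.e. the computed root is piecewise constant in x.
print(jnp.unique(roots).size)
```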
Regardless, I think raising an error when differentiating through `_bisection_search` might be better than returning a gradient of zero, because it's likely a mistake. I presume this could be done with `jax.custom_jvp`?
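Something like the following sketch (wrapper name and error message are just illustrative, not the package's API) would turn differentiation through the solver into an explicit error:

```python
# Hypothetical sketch: make differentiation through the bisection search fail
# loudly via jax.custom_jvp instead of silently returning zeros.
@jax.custom_jvp
def get_root_nodiff(x):
    return get_root(x)

@get_root_nodiff.defjvp
def _get_root_nodiff_jvp(primals, tangents):
    raise NotImplementedError(
        "Differentiating through the bisection search is not supported; "
        "the autodiff result would be zero almost everywhere."
    )

# get_root_nodiff(0.3) still evaluates normally, but
# jax.grad(get_root_nodiff)(0.3) now raises NotImplementedError.
```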