Open: h3jia opened this issue 1 year ago
Hello @h3jia,

Thanks for reporting this! The issue is that lower and upper need to be static arguments to make the implicit differentiation compatible with jit (in the solve_2 and solve_3 cases). A possible workaround is presented below. We could also change the signature of the Bisection method to let lower and upper be taken as parameters. But before changing this, could you give us a bit more context on what you want to do? A priori, grad already traces the function, so you should not need to jit it before taking the gradient.
import numpy as np
import jax.numpy as jnp
import jax
from jaxopt import Bisection
from functools import partial

@jax.jit
def _xy_c(r, phi, spin, theta_o):
    lam = spin + r / spin * (r - (2 * (r**2 - 2 * r + spin**2)) / (r - 1))
    eta = r**3 / spin**2 * ((4 * (r**2 - 2 * r + spin**2)) / (r - 1)**2 - r)
    alpha = -lam / jnp.sin(theta_o)
    beta = eta + spin**2 * jnp.cos(theta_o)**2 - lam**2 * jnp.tan(theta_o)**(-2)
    beta = jnp.sign(beta) * jnp.sqrt(jnp.abs(beta))
    return alpha, beta

@jax.jit
def _r_c_solve(r, phi, spin, theta_o):
    alpha, beta = _xy_c(r, phi, spin, theta_o)
    return (jnp.arctan2(beta, alpha) * 180. / jnp.pi + 90) % 360 - 90 - phi * 180. / jnp.pi

def r_c_solve(phi, spin, theta_o):
    phi = phi * jnp.pi / 180.
    theta_o = theta_o * jnp.pi / 180.
    theta_o = jnp.clip(theta_o, 1e-5, jnp.pi - 1e-5)
    # Compute the bracket with plain numpy: this requires a concrete spin, which is
    # why spin is declared static in the jitted versions below. lower and upper are
    # then compile-time constants rather than traced values.
    r_m = 2 * (1 + np.cos(2 / 3 * np.arccos(-spin)))
    r_p = 2 * (1 + np.cos(2 / 3 * np.arccos(spin)))
    r_0 = r_m - 0.0001 * (r_p - r_m)
    r_1 = r_p + 0.0001 * (r_p - r_m)
    return Bisection(optimality_fun=_r_c_solve, lower=r_0, upper=r_1,
                     check_bracket=False).run(phi=phi, spin=spin, theta_o=theta_o).params
g_r_c_solve_0 = jax.grad(r_c_solve)
g_r_c_solve_0(10., spin=0.9375, theta_o=163)
g_r_c_solve_1 = jax.jit(jax.grad(r_c_solve), static_argnames='spin')
g_r_c_solve_1(10., spin=0.9375, theta_o=163)
g_r_c_solve_2 = jax.grad(jax.jit(r_c_solve, static_argnames='spin'))
g_r_c_solve_2(10., spin=0.9375, theta_o=163)
g_r_c_solve_3 = jax.jit(jax.grad(jax.jit(r_c_solve, static_argnames='spin')), static_argnames='spin')
g_r_c_solve_3(10., spin=0.9375, theta_o=163)
One could also want to make the Bisection method differentiable with respect to its lower and upper values; the following code fails, for example. It would be nice to have a concrete use case to help us rethink the implementation of Bisection.
import numpy as np
import jax.numpy as jnp
import jax
from jaxopt import Bisection
from functools import partial

@jax.jit
def _xy_c(r, phi, spin, theta_o):
    lam = spin + r / spin * (r - (2 * (r**2 - 2 * r + spin**2)) / (r - 1))
    eta = r**3 / spin**2 * ((4 * (r**2 - 2 * r + spin**2)) / (r - 1)**2 - r)
    alpha = -lam / jnp.sin(theta_o)
    beta = eta + spin**2 * jnp.cos(theta_o)**2 - lam**2 * jnp.tan(theta_o)**(-2)
    beta = jnp.sign(beta) * jnp.sqrt(jnp.abs(beta))
    return alpha, beta

@jax.jit
def _r_c_solve(r, phi, spin, theta_o):
    alpha, beta = _xy_c(r, phi, spin, theta_o)
    return (jnp.arctan2(beta, alpha) * 180. / jnp.pi + 90) % 360 - 90 - phi * 180. / jnp.pi

def r_c_solve(spin, phi, theta_o):
    phi = phi * jnp.pi / 180.
    theta_o = theta_o * jnp.pi / 180.
    theta_o = jnp.clip(theta_o, 1e-5, jnp.pi - 1e-5)
    # The bracket now depends on spin through jnp operations, so lower and upper
    # become traced values that the gradient would have to flow through.
    r_m = 2 * (1 + jnp.cos(2 / 3 * jnp.arccos(-spin)))
    r_p = 2 * (1 + jnp.cos(2 / 3 * jnp.arccos(spin)))
    r_0 = r_m - 0.0001 * (r_p - r_m)
    r_1 = r_p + 0.0001 * (r_p - r_m)
    return Bisection(optimality_fun=_r_c_solve, lower=r_0, upper=r_1,
                     check_bracket=False).run(phi=phi, spin=spin, theta_o=theta_o).params
g_r_c_solve_0 = jax.grad(r_c_solve)
g_r_c_solve_0(0.9375, 10., 63)
"We could also change the signature of the Bisection method to let lower and upper be taken as parameters"

I don't think we can. lower and upper are arguments of the algorithm, not of the objective, which means they're not part of the optimality conditions. So, we can't use implicit differentiation. Unrolling will likely not work either due to discontinuous operations.
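(To spell this out a bit: implicit differentiation treats the solution r* as a function of the differentiable parameters theta via the optimality condition F(r*, theta) = 0, and the implicit function theorem gives dr*/dtheta = -(dF/dr)^(-1) dF/dtheta evaluated at the solution. lower and upper never appear in F, so there is no term through which a gradient with respect to them could flow; they only determine which root the bisection brackets.)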
Not sure if it's applicable here, but an alternative would be to use stop_gradient (see example here).
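A minimal sketch of what that could look like in the snippet above, reusing _r_c_solve and the imports from it (the name r_c_solve_sg is just for illustration, and this is only an illustration of the suggestion rather than a tested fix; as the follow-up below notes, it does not change the behaviour in this particular case):

def r_c_solve_sg(spin, phi, theta_o):
    phi = phi * jnp.pi / 180.
    theta_o = jnp.clip(theta_o * jnp.pi / 180., 1e-5, jnp.pi - 1e-5)
    r_m = 2 * (1 + jnp.cos(2 / 3 * jnp.arccos(-spin)))
    r_p = 2 * (1 + jnp.cos(2 / 3 * jnp.arccos(spin)))
    # Treat the bracket endpoints as constants as far as autodiff is concerned.
    r_0 = jax.lax.stop_gradient(r_m - 0.0001 * (r_p - r_m))
    r_1 = jax.lax.stop_gradient(r_p + 0.0001 * (r_p - r_m))
    return Bisection(optimality_fun=_r_c_solve, lower=r_0, upper=r_1,
                     check_bracket=False).run(phi=phi, spin=spin, theta_o=theta_o).params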
If you agree with me, we can relabel this issue as documentation. Adding a short paragraph on this would be helpful.
Yes, I see the issue. This would be good to know. Thanks!
Sorry for the delayed reply. In the example above, making spin static in jit is not a good idea for me, since I need this to work at many different values of spin.

The issue here does not really prevent me from computing what I want, but it does make my code ugly: I need to keep a jitted version and an unjitted version of each function, rather than just jitting everything at definition, as in the sketch below.
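Roughly, the duplicated pattern this forces looks like the following (foo is just a stand-in; the names are illustrative):

import jax

def foo(x):                        # stand-in for r_c_solve, left unjitted at definition
    return x ** 2

foo_jit = jax.jit(foo)             # used when only forward values are needed
g_foo = jax.jit(jax.grad(foo))     # the gradient is always taken of the unjitted foo

print(foo_jit(3.0), g_foo(3.0))    # 9.0 and 6.0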
I'm not really an expert on jax.jit, but is it technically possible to get a pointer to foo from jax.jit(foo)? If yes, then I think there should be a way to make jax.jit(jax.grad(jax.jit(r_c_solve))) work like jax.jit(jax.grad(r_c_solve)), i.e. just let it use the underlying unjitted function instead of the jitted one.
@mblondel not sure if I understand your comment regarding stop_gradient. Nothing changes if I do lower=jax.lax.stop_gradient(r_0), upper=jax.lax.stop_gradient(r_1) in my snippet.
Hello, I'm trying to work with the following snippet. I think it usually does not matter whether one jits the intermediate functions, i.e. jit(A(B)) is the same as jit(A(jit(B))). However, I find this is no longer the case when jaxopt.Bisection is involved. For example, the following g_r_c_solve_0 and g_r_c_solve_1 work well, but g_r_c_solve_2 and g_r_c_solve_3 give me an UnexpectedTracerError, with the full error message below. It turns out that I cannot take gradients of already jitted functions. Is it possible to fix this issue?
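For reference, here is a minimal sketch (unrelated to jaxopt, just to illustrate the composition I mean) where jitting before or after grad gives the same result:

import jax

def f(x):
    return x ** 2

g_outer = jax.jit(jax.grad(f))           # grad of the plain function, then jit
g_inner = jax.jit(jax.grad(jax.jit(f)))  # grad of an already-jitted function
print(g_outer(3.0), g_inner(3.0))        # both print 6.0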
FYI, I'm using jax=0.4.13, jaxlib=0.4.13, jaxopt=0.7 and python=3.11.4.