google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

document that lower and upper bounds in Bisection need stop_gradient for differentiation #472

Open h3jia opened 1 year ago

h3jia commented 1 year ago

Hello, I'm trying to work with the following snippet,

import numpy as np
import jax.numpy as jnp
import jax
from jaxopt import Bisection

@jax.jit
def _xy_c(r, phi, spin, theta_o):
    lam = spin + r / spin * (r - (2 * (r**2 - 2 * r + spin**2)) / (r - 1))
    eta = r**3 / spin**2 *((4 * (r**2 - 2 * r + spin**2)) / (r - 1)**2 - r)
    alpha = -lam / jnp.sin(theta_o)
    beta = eta + spin**2 * jnp.cos(theta_o)**2 - lam**2 * jnp.tan(theta_o)**(-2)
    beta = jnp.sign(beta) * jnp.sqrt(jnp.abs(beta))
    return alpha, beta

@jax.jit
def _r_c_solve(r, phi, spin, theta_o):
    alpha, beta = _xy_c(r, phi, spin, theta_o)
    return (jnp.arctan2(beta, alpha) * 180. / jnp.pi + 90) % 360 - 90 - phi * 180. / jnp.pi

def r_c_solve(phi, spin, theta_o):
    phi = phi * jnp.pi / 180.
    theta_o = theta_o * jnp.pi / 180.
    theta_o = jnp.clip(theta_o, 1e-5, jnp.pi - 1e-5)
    r_m = 2 * (1 + jnp.cos(2 / 3 * jnp.arccos(-spin)))
    r_p = 2 * (1 + jnp.cos(2 / 3 * jnp.arccos(spin)))
    r_0 = r_m - 0.0001 * (r_p - r_m)
    r_1 = r_p + 0.0001 * (r_p - r_m)
    return Bisection(optimality_fun=_r_c_solve, lower=r_0, upper=r_1,
                     check_bracket=False).run(phi=phi, spin=spin, theta_o=theta_o).params

I think it usually does not matter whether one jits the intermediate functions, i.e. jit(A(B)) is the same as jit(A(jit(B))). However, I find this is no longer the case when jaxopt.Bisection is involved. For example, the following g_r_c_solve_0 and g_r_c_solve_1 work well,

g_r_c_solve_0 = jax.grad(r_c_solve)
%time g_r_c_solve_0(10., 0.9375, 163)

g_r_c_solve_1 = jax.jit(jax.grad(r_c_solve))
%time g_r_c_solve_1(10., 0.9375, 163)

But g_r_c_solve_2 and g_r_c_solve_3 will give me an UnexpectedTracerError,

g_r_c_solve_2 = jax.grad(jax.jit(r_c_solve))
%time g_r_c_solve_2(10., 0.9375, 163)

g_r_c_solve_3 = jax.jit(jax.grad(jax.jit(r_c_solve)))
%time g_r_c_solve_3(10., 0.9375, 163)

with the full error message below,

---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
File <frozen runpy>:198, in _run_module_as_main()

File <frozen runpy>:88, in _run_code()

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/ipykernel_launcher.py:17
     15 from ipykernel import kernelapp as app
---> 17 app.launch_new_instance()

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/traitlets/config/application.py:1043, in launch_instance()
   1042 app.initialize(argv)
-> 1043 app.start()

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/ipykernel/kernelapp.py:728, in start()
    727 try:
--> 728     self.io_loop.start()
    729 except KeyboardInterrupt:

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/tornado/platform/asyncio.py:195, in start()
    194 def start(self) -> None:
--> 195     self.asyncio_loop.run_forever()

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/asyncio/base_events.py:607, in run_forever()
    606 while True:
--> 607     self._run_once()
    608     if self._stopping:

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/asyncio/base_events.py:1922, in _run_once()
   1921     else:
-> 1922         handle._run()
   1923 handle = None

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/asyncio/events.py:80, in _run()
     79 try:
---> 80     self._context.run(self._callback, *self._args)
     81 except (SystemExit, KeyboardInterrupt):

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/ipykernel/kernelbase.py:516, in dispatch_queue()
    515 try:
--> 516     await self.process_one()
    517 except Exception:

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/ipykernel/kernelbase.py:505, in process_one()
    504         return None
--> 505 await dispatch(*args)

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/ipykernel/kernelbase.py:412, in dispatch_shell()
    411     if inspect.isawaitable(result):
--> 412         await result
    413 except Exception:

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/ipykernel/kernelbase.py:740, in execute_request()
    739 if inspect.isawaitable(reply_content):
--> 740     reply_content = await reply_content
    742 # Flush output before sending the reply.

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/ipykernel/ipkernel.py:422, in do_execute()
    421 if with_cell_id:
--> 422     res = shell.run_cell(
    423         code,
    424         store_history=store_history,
    425         silent=silent,
    426         cell_id=cell_id,
    427     )
    428 else:

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/ipykernel/zmqshell.py:540, in run_cell()
    539 self._last_traceback = None
--> 540 return super().run_cell(*args, **kwargs)

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3009, in run_cell()
   3008 try:
-> 3009     result = self._run_cell(
   3010         raw_cell, store_history, silent, shell_futures, cell_id
   3011     )
   3012 finally:

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3064, in _run_cell()
   3063 try:
-> 3064     result = runner(coro)
   3065 except BaseException as e:

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/IPython/core/async_helpers.py:129, in _pseudo_sync_runner()
    128 try:
--> 129     coro.send(None)
    130 except StopIteration as exc:

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3269, in run_cell_async()
   3266 interactivity = "none" if silent else self.ast_node_interactivity
-> 3269 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
   3270        interactivity=interactivity, compiler=compiler, result=result)
   3272 self.last_execution_succeeded = not has_raised

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3448, in run_ast_nodes()
   3447     asy = compare(code)
-> 3448 if await self.run_code(code, result, async_=asy):
   3449     return True

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3508, in run_code()
   3507     else:
-> 3508         exec(code_obj, self.user_global_ns, self.user_ns)
   3509 finally:
   3510     # Reset our crash handler in place

Cell In[10], line 1
----> 1 get_ipython().run_line_magic('time', 'g_r_c_solve_2(10., 0.9375, 163)')

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/IPython/core/interactiveshell.py:2417, in run_line_magic()
   2416 with self.builtin_trap:
-> 2417     result = fn(*args, **kwargs)
   2419 # The code below prevents the output from being displayed
   2420 # when using magics with decodator @output_can_be_silenced
   2421 # when the last Python token in the expression is a ';'.

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/IPython/core/magics/execution.py:1317, in time()
   1316 try:
-> 1317     out = eval(code, glob, local_ns)
   1318 except:

File <timed eval>:1

Cell In[2], line 24, in r_c_solve()
     22 r_1 = r_p + 0.0001 * (r_p - r_m)
     23 return Bisection(optimality_fun=_r_c_solve, lower=r_0, upper=r_1,
---> 24                  check_bracket=False).run(phi=phi, spin=spin, theta_o=theta_o).params

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/jaxopt/_src/bisection.py:158, in run()
    153 def run(self,
    154         init_params: Optional[Any] = None,
    155         *args,
    156         **kwargs) -> base.OptStep:
    157   # We override run in order to set init_params=None by default.
--> 158   return super().run(init_params, *args, **kwargs)

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/jaxopt/_src/base.py:354, in run()
    352   run = decorator(run)
--> 354 return run(init_params, *args, **kwargs)

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/jaxopt/_src/implicit_diff.py:251, in wrapped_solver_fun()
    250 keys, vals = list(kwargs.keys()), list(kwargs.values())
--> 251 return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)

JaxStackTraceBeforeTransformation: jax.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was r_c_solve at /var/folders/s_/0jqkjq792dj7g4ddvdh4j5_r0000gn/T/ipykernel_89057/3431507493.py:15 traced for jit.
------------------------------
The leaked intermediate value was created on line /var/folders/s_/0jqkjq792dj7g4ddvdh4j5_r0000gn/T/ipykernel_89057/3431507493.py:21:10 (r_c_solve). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
<frozen runpy>:198:11 (_run_module_as_main)
<frozen runpy>:88:4 (_run_code)
/var/folders/s_/0jqkjq792dj7g4ddvdh4j5_r0000gn/T/ipykernel_89057/1728866946.py:1 (<module>)
<timed eval>:1 (<module>)
/var/folders/s_/0jqkjq792dj7g4ddvdh4j5_r0000gn/T/ipykernel_89057/3431507493.py:21:10 (r_c_solve)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

UnexpectedTracerError                     Traceback (most recent call last)
File <timed eval>:1

    [... skipping hidden 33 frame]

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/jaxopt/_src/implicit_diff.py:210, in _custom_root.<locals>.make_custom_vjp_solver_fun.<locals>.solver_fun_fwd(*flat_args)
    209 def solver_fun_fwd(*flat_args):
--> 210   res = solver_fun_flat(*flat_args)
    211   return res, (res, flat_args)

    [... skipping hidden 4 frame]

File ~/miniconda3/envs/hejia@mac-2/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:1835, in DynamicJaxprTrace.getvar(self, tracer)
   1833 var = self.frame.tracer_to_var.get(id(tracer))
   1834 if var is None:
-> 1835   raise core.escaped_tracer_error(tracer)
   1836 return var

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was r_c_solve at /var/folders/s_/0jqkjq792dj7g4ddvdh4j5_r0000gn/T/ipykernel_89057/3431507493.py:15 traced for jit.
------------------------------
The leaked intermediate value was created on line /var/folders/s_/0jqkjq792dj7g4ddvdh4j5_r0000gn/T/ipykernel_89057/3431507493.py:21:10 (r_c_solve). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
<frozen runpy>:198:11 (_run_module_as_main)
<frozen runpy>:88:4 (_run_code)
/var/folders/s_/0jqkjq792dj7g4ddvdh4j5_r0000gn/T/ipykernel_89057/1728866946.py:1 (<module>)
<timed eval>:1 (<module>)
/var/folders/s_/0jqkjq792dj7g4ddvdh4j5_r0000gn/T/ipykernel_89057/3431507493.py:21:10 (r_c_solve)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

It turns out that I cannot take gradients of already jitted functions. Is it possible to fix this issue?

FYI, I'm using jax=0.4.13, jaxlib=0.4.13, jaxopt=0.7 and python=3.11.4.

vroulet commented 1 year ago

Hello @hjia,

Thanks for reporting this! The issue is that lower and upper need to be static arguments for implicit differentiation to be compatible with jit (the g_r_c_solve_2 and g_r_c_solve_3 cases). A possible workaround is presented below. We could also change the signature of the Bisection method to let lower and upper be taken as parameters. But before changing this, could you give us a bit more context on what you want to do? A priori, grad already traces the function, so you should not need to jit it before taking the gradient.

import numpy as np
import jax.numpy as jnp
import jax
from jaxopt import Bisection

@jax.jit
def _xy_c(r, phi, spin, theta_o):
    lam = spin + r / spin * (r - (2 * (r**2 - 2 * r + spin**2)) / (r - 1))
    eta = r**3 / spin**2 *((4 * (r**2 - 2 * r + spin**2)) / (r - 1)**2 - r)
    alpha = -lam / jnp.sin(theta_o)
    beta = eta + spin**2 * jnp.cos(theta_o)**2 - lam**2 * jnp.tan(theta_o)**(-2)
    beta = jnp.sign(beta) * jnp.sqrt(jnp.abs(beta))
    return alpha, beta

@jax.jit
def _r_c_solve(r, phi, spin, theta_o):
    alpha, beta = _xy_c(r, phi, spin, theta_o)
    return (jnp.arctan2(beta, alpha) * 180. / jnp.pi + 90) % 360 - 90 - phi * 180. / jnp.pi

def r_c_solve(phi, spin, theta_o):
    phi = phi * jnp.pi / 180.
    theta_o = theta_o * jnp.pi / 180.
    theta_o = jnp.clip(theta_o, 1e-5, jnp.pi - 1e-5)
    r_m = 2 * (1 + np.cos(2 / 3 * np.arccos(-spin)))
    r_p = 2 * (1 + np.cos(2 / 3 * np.arccos(spin)))
    r_0 = r_m - 0.0001 * (r_p - r_m)
    r_1 = r_p + 0.0001 * (r_p - r_m)
    return Bisection(optimality_fun=_r_c_solve, lower=r_0, upper=r_1,
                     check_bracket=False).run(phi=phi, spin=spin, theta_o=theta_o).params

g_r_c_solve_0 = jax.grad(r_c_solve)
g_r_c_solve_0(10., spin=0.9375, theta_o=163)

g_r_c_solve_1 = jax.jit(jax.grad(r_c_solve), static_argnames='spin')
g_r_c_solve_1(10., spin=0.9375, theta_o=163)

g_r_c_solve_2 = jax.grad(jax.jit(r_c_solve, static_argnames='spin'))
g_r_c_solve_2(10., spin=0.9375, theta_o=163)

# The outer jit must also treat spin as static, since the inner jit does.
g_r_c_solve_3 = jax.jit(jax.grad(jax.jit(r_c_solve, static_argnames='spin')), static_argnames='spin')
g_r_c_solve_3(10., spin=0.9375, theta_o=163)

vroulet commented 1 year ago

We might want to make the Bisection method differentiable with respect to its lower and upper values. For example, the following code fails. It would therefore be helpful to have a use case before we rethink the implementation of Bisection.

import numpy as np
import jax.numpy as jnp
import jax
from jaxopt import Bisection

@jax.jit
def _xy_c(r, phi, spin, theta_o):
    lam = spin + r / spin * (r - (2 * (r**2 - 2 * r + spin**2)) / (r - 1))
    eta = r**3 / spin**2 *((4 * (r**2 - 2 * r + spin**2)) / (r - 1)**2 - r)
    alpha = -lam / jnp.sin(theta_o)
    beta = eta + spin**2 * jnp.cos(theta_o)**2 - lam**2 * jnp.tan(theta_o)**(-2)
    beta = jnp.sign(beta) * jnp.sqrt(jnp.abs(beta))
    return alpha, beta

@jax.jit
def _r_c_solve(r, phi, spin, theta_o):
    alpha, beta = _xy_c(r, phi, spin, theta_o)
    return (jnp.arctan2(beta, alpha) * 180. / jnp.pi + 90) % 360 - 90 - phi * 180. / jnp.pi

def r_c_solve(spin, phi, theta_o):
    phi = phi * jnp.pi / 180.
    theta_o = theta_o * jnp.pi / 180.
    theta_o = jnp.clip(theta_o, 1e-5, jnp.pi - 1e-5)
    r_m = 2 * (1 + jnp.cos(2 / 3 * jnp.arccos(-spin)))
    r_p = 2 * (1 + jnp.cos(2 / 3 * jnp.arccos(spin)))
    r_0 = r_m - 0.0001 * (r_p - r_m)
    r_1 = r_p + 0.0001 * (r_p - r_m)
    return Bisection(optimality_fun=_r_c_solve, lower=r_0, upper=r_1,
                     check_bracket=False).run(phi=phi, spin=spin, theta_o=theta_o).params

g_r_c_solve_0 = jax.grad(r_c_solve)
g_r_c_solve_0(0.9375, 10., 63)

mblondel commented 1 year ago

"We could also change the signature of the Bisection method to let lower and upper be taken as parameters"

I don't think we can. lower and upper are arguments of the algorithm, not of the objective, which means they're not part of the optimality conditions. So, we can't use implicit differentiation. Unrolling will likely not work either due to discontinuous operations.

Not sure if it's applicable here but an alternative would be to use stop_gradient (see example here).
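To illustrate the point about optimality conditions, here is a minimal sketch in plain Python (no JAX or jaxopt). The names `bisect`, `F` and `implicit_grad` and the cubic toy problem are assumptions made for illustration only; they are not jaxopt's implementation and not the function from this thread. The root x* satisfies F(x*, theta) = 0, so the implicit function theorem gives dx*/dtheta = -(dF/dtheta) / (dF/dx). The bracket appears nowhere in that formula, which is why implicit differentiation has nothing to say about lower and upper.

```python
def bisect(fun, lower, upper, theta, tol=1e-12, maxiter=200):
    """Plain bisection for fun(x, theta) = 0 on [lower, upper]."""
    f_lo = fun(lower, theta)
    for _ in range(maxiter):
        mid = 0.5 * (lower + upper)
        f_mid = fun(mid, theta)
        # Keep the half-interval whose endpoints still have opposite signs.
        if (f_mid > 0) == (f_lo > 0):
            lower, f_lo = mid, f_mid
        else:
            upper = mid
        if upper - lower < tol:
            break
    return 0.5 * (lower + upper)


def F(x, theta):
    # Hypothetical optimality function; its root is theta ** (1 / 3).
    return x**3 - theta


def implicit_grad(x_star, theta):
    # Implicit function theorem: dx*/dtheta = -(dF/dtheta) / (dF/dx).
    dF_dx = 3.0 * x_star**2
    dF_dtheta = -1.0
    return -dF_dtheta / dF_dx


theta = 2.0
# Two different brackets around the same root give the same solution.
root_a = bisect(F, 0.0, 3.0, theta)
root_b = bisect(F, 1.0, 10.0, theta)

# The implicit gradient only needs F and the root, never the bracket.
grad_implicit = implicit_grad(root_a, theta)

# Central finite difference through the solver, as a sanity check.
eps = 1e-6
grad_fd = (bisect(F, 0.0, 3.0, theta + eps)
           - bisect(F, 0.0, 3.0, theta - eps)) / (2 * eps)
```

As long as the bracket contains a single root, moving its endpoints changes the solution only within the tolerance, so the derivative with respect to the bounds is zero wherever it is defined.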

mblondel commented 1 year ago

If you agree with me, we can relabel this issue as documentation. Adding a short paragraph on this would be helpful.

vroulet commented 1 year ago

Yes, I see the issue. This would be good to know. Thanks!

h3jia commented 1 year ago

Sorry for the delayed reply. In the example above, making spin static in jit is not a good idea for me, since I need this to work for many different values of spin.

The issue here does not really prevent me from computing what I want, but it does make my code ugly. I need to have a jitted version and an unjitted version of each function, rather than just jitting everything at definition.

I'm not really an expert on jax.jit, but technically is it possible to get some pointer towards foo from jax.jit(foo)? If yes, then I think there should be a way to make jax.jit(jax.grad(jax.jit(r_c_solve))) work similarly to jax.jit(jax.grad(r_c_solve)), i.e. just let it use the underlying unjitted function instead of the jitted one.
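On the question of recovering foo from jax.jit(foo): a small sketch, assuming the jitted object exposes the original callable through a `__wrapped__` attribute (the functools.wraps convention, which the JAX versions I am aware of appear to follow). The function `foo` below is a hypothetical stand-in, not the solver from this thread, and whether the same trick sidesteps the Bisection error has not been checked here.

```python
import jax
import jax.numpy as jnp


def foo(x):
    # Hypothetical stand-in for a function that was jitted at definition.
    return jnp.sin(x) * x


jitted_foo = jax.jit(foo)

# Assumption: jax.jit keeps a reference to the original Python function
# under `__wrapped__`. Options such as static_argnames are not carried over.
plain_foo = jitted_foo.__wrapped__

# Differentiate the unjitted function, then jit the result once.
grad_via_unwrapped = jax.jit(jax.grad(plain_foo))
grad_direct = jax.jit(jax.grad(foo))

value_unwrapped = float(grad_via_unwrapped(1.0))
value_direct = float(grad_direct(1.0))
```

Note that only the outermost jit is removed this way; functions called inside `foo` that carry their own `@jax.jit` decorator stay jitted.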

h3jia commented 1 year ago

@mblondel not sure if I understand your comment regarding stop_gradient. Nothing changes if I do lower=jax.lax.stop_gradient(r_0), upper=jax.lax.stop_gradient(r_1) in my snippet.
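For context on what stop_gradient does and does not do, here is a minimal sketch using only JAX. The functions `bracket_center`, `with_flow` and `without_flow` are hypothetical names for illustration. `jax.lax.stop_gradient` makes differentiation treat a value as a constant, but the value is still traced under jit, so it does not turn a traced bound into a static one. That would be consistent with the jitted cases above still failing.

```python
import math

import jax
import jax.numpy as jnp


def bracket_center(spin):
    # Hypothetical stand-in for a bound computed from a differentiable input.
    return 2.0 * (1.0 + jnp.cos(spin))


def with_flow(spin):
    # Gradient flows through both factors.
    return spin * bracket_center(spin)


def without_flow(spin):
    # The bound is treated as a constant during differentiation.
    return spin * jax.lax.stop_gradient(bracket_center(spin))


spin = 0.5
g_with = float(jax.grad(with_flow)(spin))
g_without = float(jax.grad(without_flow)(spin))

# Closed-form values for comparison.
expected_without = 2.0 * (1.0 + math.cos(spin))
expected_with = expected_without - 2.0 * spin * math.sin(spin)
```

The two gradients differ by exactly the term that flows through the bound, which is the contribution stop_gradient removes.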