google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

Auto check `jit=True` if solver is `vmapped`? #242

Open joeryjoery opened 2 years ago

joeryjoery commented 2 years ago

I'm not sure whether this is possible, but I found out rather awkwardly that vmapping a solver strictly requires the solver to be staged, i.e. created with `jit=True`.

For example, this code works fine:

import jax
import jax.numpy as jnp
import jaxopt

def fun(x, y):
    return jnp.square(x - y) + y

solver = jaxopt.LBFGS(fun=fun, maxiter=5, jit=True)

params, state = solver.run(jax.device_put(0.1), y=jax.device_put(2.1))

batch_params, batch_state = jax.vmap(lambda a, k: solver.run(a, **k), in_axes=(0, None))(
    jnp.repeat(0.1, 12), {'y': jax.device_put(2.1)})

# No errors raised

However, this code raises a `ConcretizationTypeError` because the solver is created with `jit=False`:

import jax
import jax.numpy as jnp
import jaxopt

def fun(x, y):
    return jnp.square(x - y) + y

solver = jaxopt.LBFGS(fun=fun, maxiter=5, jit=False)

params, state = solver.run(jax.device_put(0.1), y=jax.device_put(2.1))  # Works fine

batch_params, batch_state = jax.vmap(lambda a, k: solver.run(a, **k), in_axes=(0, None))(
    jnp.repeat(0.1, 12), {'y': jax.device_put(2.1)})  # ConcretizationTypeError due to vmap

The exception is raised by the `cond_fun` of the solver's while loop, so on inspection I understand why this happens. However, the solver's arguments led me to believe that jitting was optional, whereas under `vmap` it is effectively required.
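To make the failure mode concrete without relying on jaxopt internals, here is a minimal pure-JAX sketch. `minimize` is a hypothetical toy gradient-descent solver, not jaxopt's implementation; it only mimics a `jit` flag that switches between an eager Python `while` loop and `jax.lax.while_loop`. Under `jax.vmap` the loop condition becomes a batched tracer, which the Python loop cannot convert to `bool`, while the staged loop accepts it.

```python
import jax
import jax.numpy as jnp


def fun(x, y):
    # Same objective as in the issue; it is minimized at x = y.
    return jnp.square(x - y) + y


def minimize(x0, y, jit, stepsize=0.25, maxiter=100, tol=1e-6):
    """Hypothetical toy gradient-descent solver (not jaxopt's implementation)."""
    grad_fun = jax.grad(fun)

    def cond_fun(state):
        i, x = state
        # Continue while under the iteration budget and the gradient is large.
        return (i < maxiter) & (jnp.abs(grad_fun(x, y)) > tol)

    def body_fun(state):
        i, x = state
        return i + 1, x - stepsize * grad_fun(x, y)

    state = (jnp.array(0), jnp.asarray(x0, dtype=jnp.float32))
    if jit:
        # Staged loop: the condition stays a traced value, so vmap can batch it.
        state = jax.lax.while_loop(cond_fun, body_fun, state)
    else:
        # Eager loop: `while` calls bool() on the condition, which requires a
        # concrete value and fails when x is a batch tracer under vmap.
        while cond_fun(state):
            state = body_fun(state)
    return state[1]


xs = jnp.repeat(0.1, 12)
single = minimize(0.1, 2.1, jit=False)                        # eager loop works
batched = jax.vmap(lambda a: minimize(a, 2.1, jit=True))(xs)  # staged loop works

error_name = None
try:
    jax.vmap(lambda a: minimize(a, 2.1, jit=False))(xs)
except jax.errors.ConcretizationTypeError as err:
    error_name = type(err).__name__  # ConcretizationTypeError or a subclass
```

The unbatched eager call succeeds because every value is concrete; only the combination of `vmap` and the Python loop fails.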

Could this be inferred automatically? Or is this just one of JAX's sharp edges? If so, I think it would be useful to note it in the documentation of the `jit` argument of the solver classes.
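One way the flag could be inferred is sketched below, assuming an eager loop is only wanted when the inputs are concrete. `bounded_while_loop` and `solve` are hypothetical helpers, not part of jaxopt: with `jit="auto"` they try the Python loop first and fall back to `jax.lax.while_loop` when JAX raises a `ConcretizationTypeError`.

```python
import jax
import jax.numpy as jnp


def bounded_while_loop(cond_fun, body_fun, init_val, jit="auto"):
    """Hypothetical loop helper; `jit` may be True, False or "auto".

    With jit="auto" it first tries an eager Python loop and falls back to
    jax.lax.while_loop if the condition cannot be turned into a Python bool,
    as happens under jax.vmap or jax.jit.
    """
    if jit is True:
        return jax.lax.while_loop(cond_fun, body_fun, init_val)
    try:
        val = init_val
        while cond_fun(val):
            val = body_fun(val)
        return val
    except jax.errors.ConcretizationTypeError:
        if jit is False:
            raise  # The caller explicitly asked for an eager loop.
        # Traced inputs detected: restart the loop in staged form.
        return jax.lax.while_loop(cond_fun, body_fun, init_val)


def solve(x0, y, jit="auto"):
    # Hypothetical toy solver: iterate x <- (x + y) / 2, which converges to y.
    def cond_fun(state):
        i, x = state
        return (i < 100) & (jnp.abs(x - y) > 1e-6)

    def body_fun(state):
        i, x = state
        return i + 1, 0.5 * (x + y)

    init = (jnp.array(0), jnp.asarray(x0, dtype=jnp.float32))
    return bounded_while_loop(cond_fun, body_fun, init, jit=jit)[1]


xs = jnp.repeat(0.1, 12)
eager = solve(0.1, 2.1)                          # runs the Python loop
batched = jax.vmap(lambda a: solve(a, 2.1))(xs)  # falls back to lax.while_loop
```

Catching the exception is a blunt heuristic: the loop restarts from its initial value in staged form, and an unrelated concretization error raised inside the objective would also trigger the fallback.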

mblondel commented 2 years ago

Hi @joeryjoery. Indeed, `jit=True` is necessary for a solver to be "vmappable". We can definitely improve the documentation, but I am not sure whether it's possible to detect this case and raise an error message. @froystig will definitely know more.
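If automatic inference is not desirable, a clearer error message could at least be raised. The sketch below wraps a hypothetical `run` function (not jaxopt's `solver.run`) with a hypothetical `explain_jit_requirement` decorator that turns the `ConcretizationTypeError` from the eager loop into a `ValueError` suggesting `jit=True`.

```python
import functools

import jax
import jax.numpy as jnp


def explain_jit_requirement(run):
    """Hypothetical decorator for a run function with a keyword-only `jit` flag."""

    @functools.wraps(run)
    def wrapper(*args, jit=True, **kwargs):
        try:
            return run(*args, jit=jit, **kwargs)
        except jax.errors.ConcretizationTypeError as err:
            if jit:
                raise  # The loop is staged, so this error has another cause.
            raise ValueError(
                "The solver's Python while loop received traced values "
                "(e.g. under jax.vmap or jax.jit) and cannot iterate on them "
                "eagerly. Try creating the solver with jit=True."
            ) from err

    return wrapper


@explain_jit_requirement
def run(x0, y, *, jit=True):
    # Hypothetical toy solver: iterate x <- (x + y) / 2 until |x - y| is small.
    def cond_fun(state):
        i, x = state
        return (i < 100) & (jnp.abs(x - y) > 1e-6)

    def body_fun(state):
        i, x = state
        return i + 1, 0.5 * (x + y)

    state = (jnp.array(0), jnp.asarray(x0, dtype=jnp.float32))
    if jit:
        return jax.lax.while_loop(cond_fun, body_fun, state)[1]
    while cond_fun(state):  # Eager loop: needs a concrete condition.
        state = body_fun(state)
    return state[1]


xs = jnp.repeat(0.1, 12)
message = None
try:
    jax.vmap(lambda a: run(a, 2.1, jit=False))(xs)
except ValueError as err:
    message = str(err)
```

Since `jit=False` and a `ConcretizationTypeError` can also coincide for other reasons (for example a Python `if` on a traced value inside the objective), the message is phrased as a suggestion rather than a diagnosis.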