google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

Wrong failure diagnostic print outs from `ZoomLineSearch` under `vmap` #555

Closed: tare closed this issue 8 months ago

tare commented 10 months ago

Environment

% pip list|grep jax   
jax                       0.4.20
jaxlib                    0.4.20
jaxopt                    0.8.2

% python --version
Python 3.10.11

Description

Under vmap, ZoomLineSearch ends up calling failure_diagnostic() even when safe_stepsize > 0. This can produce a lot of printouts, and given the current implementation I didn't see a way to disable them. I think the relevant commit is https://github.com/google/jaxopt/commit/614dc7bf769628eee6f72e636cb608c0f6678596. Minimal reproducible examples follow.

The following code

import jax.numpy as jnp
from jax import vmap
from jaxopt import LBFGS

def solve(x, y):
    # Minimize (y - x)^2 over x with L-BFGS and a zoom line search.
    solver = LBFGS(lambda x, y: jnp.square(y - x), linesearch="zoom")
    x, _ = solver.run(x, y=y)
    return x

x_init = jnp.zeros(())
ys = jnp.arange(1)

# Batch over ys while sharing the same initial point.
vmap(solve, in_axes=(None, 0))(x_init, ys)

gives the following warnings

WARNING: jaxopt.ZoomLineSearch: Linesearch failed, no stepsize satisfying sufficient decrease found.
INFO: jaxopt.ZoomLineSearch: Iter: 1, Stepsize: 1.0, Decrease error: -0.0, Curvature error: 0.0
WARNING: jaxopt.ZoomLineSearch: The linesearch failed because the provided direction is not a descent direction. The slope (=-0.0) at stepsize=0 should be negative
WARNING: jaxopt.ZoomLineSearch: Consider augmenting the maximal number of linesearch iterations.
WARNING: jaxopt.ZoomLineSearch: Computed stepsize (=1.0) is below machine precision (=1.1920928955078125e-07), consider passing to higher precision like x64, using jax.config.update('jax_enable_x64', True).
WARNING: jaxopt.ZoomLineSearch: Very large absolute slope at stepsize=0. (|slope|=0.0). The objective is badly conditioned. Consider reparameterizing objective (e.g., normalizing parameters) or finding a better guess for the initial parameters for the solver.
WARNING: jaxopt.ZoomLineSearch: Cannot even make a step without getting Inf or Nan. The linesearch won't make a step and the optimizer is stuck.
WARNING: jaxopt.ZoomLineSearch: Making an unsafe step, not decreasing enough the objective. Convergence of the solver is compromised as it does not reduce values.

By contrast, the following code, which loops with jax.lax.map over a jitted function instead of using vmap, produces no warnings:

import jax.numpy as jnp
from jax import jit
from jax.lax import map
from jaxopt import LBFGS

def solve(x, y):
    # Minimize (y - x)^2 over x with L-BFGS and a zoom line search.
    solver = LBFGS(lambda x, y: jnp.square(y - x), linesearch="zoom")
    x, _ = solver.run(x, y=y)
    return x

x_init = jnp.zeros(())
ys = jnp.arange(1)

# Process ys sequentially with lax.map instead of vmap.
res = map(jit(lambda y: solve(x_init, y)), ys)

The underlying behavior can be reproduced with just jax.debug.print, cond, and vmap. The following code

import jax
import jax.numpy as jnp
from jax import vmap, jit
from jax.lax import cond, map

def test(x):
    def true_fun(x):
        pass
    def false_fun(x):
        # Runtime print; expected to fire only when x >= 3.
        jax.debug.print("{}", x)
    cond(x < 3, true_fun, false_fun, x)

print("map and jit")
map(jit(test), jnp.arange(5))
print("vmap")
vmap(test)(jnp.arange(5))

gives the following output

map and jit
3
4
vmap
0
1
2
3
4
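To see why vmap prints all five values, one can inspect the traced program with jax.make_jaxpr. The sketch below uses a hypothetical, print-free function `branchy` with the same cond structure as `test`. It assumes current JAX behavior, in which a cond whose predicate is batched by vmap is rewritten into a `select_n` over the outputs of both branches. The exact jaxpr text varies between JAX versions, so the check only looks for primitive names.

```python
import jax
import jax.numpy as jnp
from jax import lax

def branchy(x):
    # Hypothetical stand-in for `test`: same cond structure, no side effects.
    return lax.cond(x < 3, lambda v: v, lambda v: v * 10, x)

# Unbatched: the predicate is a scalar, so a real `cond` primitive is traced
# and only the selected branch runs at execution time.
unbatched = str(jax.make_jaxpr(branchy)(2))

# Batched predicate under vmap: both branches are inlined and their results
# are combined with `select_n`, so both branches always execute.
batched = str(jax.make_jaxpr(jax.vmap(branchy))(jnp.arange(5)))

print("cond" in unbatched)   # True
print("select_n" in batched)  # True
```

Because both branches run, any side effect placed in the "failure" branch (such as a jax.debug.print) fires for every batch element, regardless of the predicate.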
vroulet commented 10 months ago

Hello @tare, thanks for pointing this out. When the predicate is batched, vmap evaluates both branches of a cond (this does not happen without vmap); see https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.cond.html. I'm therefore not sure how we could have failure diagnostics under vmap. In the meantime, I have patched zoom in #544 so that it does not display failure diagnostics unless verbose is set to True. That will avoid unnecessary prints.
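One possible way to keep diagnostics accurate under vmap, sketched here as an assumption and not as what jaxopt does (#544 instead gates the diagnostics on verbose), is to move the decision out of lax.cond and into the host callback. JAX's batching rule for debug callbacks calls the callback once per batch element, so a plain Python `if` inside the callback sees each element's own failure flag. The names `hypothetical_diagnostic` and `_maybe_report`, and the "stepsize <= 0 means failure" rule, are hypothetical.

```python
import jax
import jax.numpy as jnp

# Host-side record of the diagnostics that were actually emitted.
emitted = []

def _maybe_report(failed, stepsize):
    # Runs on the host once per vmapped element, so `failed` is that
    # element's own flag rather than a batch-wide condition.
    if bool(failed):
        emitted.append(f"linesearch failed, stepsize={float(stepsize)}")

def hypothetical_diagnostic(stepsize):
    # Always invoke the callback and let the host decide whether to report,
    # instead of placing the report inside a lax.cond branch.
    jax.debug.callback(_maybe_report, stepsize <= 0.0, stepsize)
    return stepsize

stepsizes = jnp.array([1.0, 0.0, 0.5])
jax.vmap(hypothetical_diagnostic)(stepsizes)
jax.effects_barrier()  # wait for pending host callbacks to finish
print(emitted)  # ['linesearch failed, stepsize=0.0']
```

The trade-off is that a host callback is still dispatched for every element on every call, even when nothing is reported, which may cost more than skipping it with lax.cond outside of vmap.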

tare commented 10 months ago

Thanks for the quick reply and pointing out https://github.com/google/jaxopt/pull/544! I hope that PR gets merged soon.

vroulet commented 8 months ago

Closing as #544 has been merged.