google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

Deadlock possibly caused by JaxOpt #387

Open NeilGirdhar opened 1 year ago

NeilGirdhar commented 1 year ago

I have a mysterious deadlock that was introduced when I replaced my gradient descent with JaxOpt. I reported it as a Jax bug, but could it be a JaxOpt bug? The MWE is about 100 lines and can be reproduced using these instructions. I did comb through the JaxOpt source code, and it seems that all while loops are guarded by maximum-iteration conditions. Is there something I could be missing?

NeilGirdhar commented 1 year ago

@zaccharieramzi Since you had a recent commit about LBFGS, would you be able to let me know whether my error could be caused by a failed line search in LBFGS?

zaccharieramzi commented 1 year ago

Hi @NeilGirdhar: looking at your linked MWE, I am not sure where in your code you use LBFGS. To know whether a failed line search is happening when using LBFGS, you can always check the state's failed_linesearch attribute, which should tell you whether the line search has failed or not. You can also just enable stop_if_linesearch_fails when instantiating the solver.
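For reference, a minimal sketch of what that check might look like on a toy problem (the quadratic objective here is made up purely for illustration):

import jax.numpy as jnp
from jaxopt import LBFGS

def fun(x):
    # Toy quadratic objective, just to illustrate the API.
    return jnp.sum(jnp.square(x - 1.0))

# stop_if_linesearch_fails makes the solver exit as soon as the line search fails.
solver = LBFGS(fun, maxiter=100, tol=1e-3, stop_if_linesearch_fails=True)
params, state = solver.run(jnp.zeros(3))
print(state.failed_linesearch)  # whether the last line search failed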

However, I doubt this is a line search error: what I corrected was not something that created deadlocks, just longer optimization runs. How many iterations do you need to reach a freeze? Maybe a stupid question, but how do you know it's a freeze and not just the algorithm being super slow?

I saw in your older issue that disabling jit solves the problem, so unfortunately I can't really give you smart debugging advice. I am really new to jax, so it's difficult for me to help in that regard.

zaccharieramzi commented 1 year ago

Also pinging @mblondel, who might know more about this since he developed LBFGS.

NeilGirdhar commented 1 year ago

Hi @NeilGirdhar: looking at your linked MWE, I am not sure where in your code you use LBFGS.

Oh, my mistake, I switched it to GradientDescent for the time being.

stop_if_linesearch_fails

Thanks, I'll keep this in mind when I switch back to LBFGS.

How many iterations do you need to reach a freeze? Maybe a stupid question, but how do you know it's a freeze and not just the algorithm being super slow?

Five iterations of Jax code. I'm not sure how many iterations of the solver, but maxiter is set to 250. I have let it run for a very long time. The reason I don't think it's just slow is that that would mean JaxOpt is orders of magnitude slower than my naively implemented gradient descent, and it wouldn't explain the problem disappearing when switching to 64 bits.

Something very strange is happening, either in Jax's XLA compiler, or in JaxOpt's loops.

I saw in your older issue that disabling jit solves the problem, so unfortunately I can't really give you smart debugging advice. I am really new to jax, so it's difficult for me to help in that regard.

No worries, I really appreciate the speedy reply. I've been struggling for a month now!

zaccharieramzi commented 1 year ago

Ok, maybe you can try searching for the minimal setting of maxiter that triggers this bug. Also, it could still be that the algorithm is super slow in 32 bits because it doesn't reach the tolerance for some floating-point reason, and does in 64 bits (but that's just a wild guess).

But does your code work when using jaxopt's GradientDescent? (also pay attention to the fact that jaxopt's GradientDescent is by default accelerated)

NeilGirdhar commented 1 year ago

Ok, maybe you can try searching for the minimal setting of maxiter that triggers this bug.

I got it down to maxiter=11.

Also, it could still be that the algorithm is super slow in 32 bits because it doesn't reach the tolerance for some floating-point reason, and does in 64 bits (but that's just a wild guess).

Interesting idea, but my tolerance is 0.001. Is that low enough to cause problems?

But does your code work when using jaxopt's GradientDescent? (also pay attention to the fact that jaxopt's GradientDescent is by default accelerated)

It doesn't work with GradientDescent.

(also pay attention to the fact that jaxopt's GradientDescent is by default accelerated)

Okay, but I'm running on the CPU only. Unless you mean jitted? If so, that's related to the bug, since turning off the JIT prevents it.

zaccharieramzi commented 1 year ago

No no, with this tolerance it should be fine.

Haaa ok, it doesn't work with GradientDescent, I thought it did. Have you tried removing all "options" from the solver:

- setting acceleration=False to get rid of the FISTA step
- taking a constant stepsize or the schedule of your choice (or maybe just reducing maxls, the maximum number of line search iterations)

We might pin down the issue if you try with either of these (or both) turned off.
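For reference, a minimal sketch of the solver instantiated with these options turned off (the placeholder objective and the constant stepsize value are just examples; maxiter and tol follow the settings discussed above):

import jax.numpy as jnp
from jaxopt import GradientDescent

def fun(x):
    # Placeholder objective; substitute your own.
    return jnp.sum(jnp.square(x))

# No FISTA acceleration and a constant stepsize, so no line search at all.
solver = GradientDescent(fun, maxiter=250, tol=0.001, acceleration=False, stepsize=1e-3)
print(solver.run(jnp.ones(3)).params)

# Or: keep the line search but cap its number of iterations.
solver = GradientDescent(fun, maxiter=250, tol=0.001, acceleration=False, maxls=5)
print(solver.run(jnp.ones(3)).params)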

NeilGirdhar commented 1 year ago

setting acceleration=False to get rid of the FISTA step

I've already set it to false.

taking a constant stepsize or the schedule of your choice (or maybe just reducing maxls, the maximum number of line search iterations)

Lowering maxls from its default of 15 down to 5 does make the program get farther, but it eventually freezes. At 2, it succeeds. At 4, it produces a bunch of nans and doesn't freeze.

Setting a constant step size of 1e-3 allows the program to complete normally.

NeilGirdhar commented 1 year ago

I guess one thing that would definitely cause this is if the objective function f returning nan caused JaxOpt's gradient descent to loop forever. I would consider that a bug in JaxOpt.

zaccharieramzi commented 1 year ago

It shouldn't, because every while_loop is capped by maxiter, but indeed you could test this in a very simple setup with a function that returns nan randomly. Unfortunately, since this seems related to the very specific fista_linesearch implemented here, I am afraid I am not able to help much further at this stage. I guess @mblondel will have much more to say (maybe after the ICML deadline).
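For what it's worth, a minimal sketch of such a test might look like this (the toy objective and the nan trigger are made up purely for illustration):

import jax
import jax.numpy as jnp
from jaxopt import GradientDescent

def objective(x):
    # Toy quadratic that returns nan once the loss gets small,
    # mimicking an objective that becomes numerically unstable.
    value = jnp.sum(jnp.square(x - 3.0))
    return jnp.where(value < 1.0, jnp.nan, value)

solver = GradientDescent(objective, maxiter=50, tol=1e-3, acceleration=False)
# If maxiter is respected, this returns even when the objective produces nans.
result = jax.jit(solver.run)(jnp.zeros(3))
print(result.params)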

NeilGirdhar commented 1 year ago

@zaccharieramzi Okay, thanks so much for your kind help!

NeilGirdhar commented 1 year ago

@mblondel Would you happen to have some time to look at this blocking bug for me?

mblondel commented 1 year ago

Do you have a reproducing example in a single script that I can execute?

NeilGirdhar commented 1 year ago

@mblondel It is in a single script. (I linked it on the issue.) However, if you do the poetry installation, then it'll guarantee that you have the same environment too.

mblondel commented 1 year ago

Maybe I missed something but I don't see any main function in this file.

NeilGirdhar commented 1 year ago

@mblondel Oh, I'm sorry, cli is the entry point. You can simply append cli() to the file to run it.

mblondel commented 1 year ago

Did you try running your program in float64 precision?

NeilGirdhar commented 1 year ago

@mblondel Yes, that works.

I appreciate your taking a look. I don't think running in 64-bit precision is a reasonable workaround since I don't think Jax should ever lock up like this. Something has gone seriously wrong either with Jax's compilation or in JaxOpt.

mblondel commented 1 year ago

The fact that it works in float64 suggests that there might be a way to work around the stability issue by writing the objective function differently. I'm not sure the bug comes from JAXopt since, as you said, our loops all have a maximum number of iterations.
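For reference, the float64 run amounts to enabling JAX's x64 mode before any arrays are created, e.g.:

import jax

# Must run before any jnp arrays are created, otherwise they stay float32.
jax.config.update("jax_enable_x64", True)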

NeilGirdhar commented 1 year ago

The fact that it works in float64 suggests that there might be a way to work around the stability issue by writing the objective function differently.

How does a "stability issue" cause a deadlock?

I'm not sure the bug comes from JAXopt,

Right, I'm not sure either. But the alternative is that it comes from Jax, which means some kind of compiler bug. Were you able to reproduce it?

NeilGirdhar commented 1 year ago

@mblondel Sorry to keep bugging you on this. I really want to use JaxOpt, but I'm blocked on this bug. Do you think this is likely to be a problem with Jax? Were you able to reproduce this bug?

mblondel commented 1 year ago

I can give it a try, but I need a single script (as short as possible) that I can run without additional dependencies.

NeilGirdhar commented 1 year ago

@mblondel Okay thanks, I'll start putting that together for you.

NeilGirdhar commented 1 year ago

No dependencies except Jax and JaxOpt.

from __future__ import annotations

from collections.abc import Callable
from functools import partial
from typing import NamedTuple

import jax.numpy as jnp
from jax import Array, custom_vjp, grad, jit, vjp, vmap
from jax.lax import dot
from jax.nn import softplus
from jax.tree_util import tree_map
from jaxopt import GradientDescent

class Weights(NamedTuple):
    b1: Array
    w1: Array
    b2: Array
    w2: Array

class AdamState(NamedTuple):
    count: Array
    mu: Weights
    nu: Weights

def update_moment(updates, moments, decay, order):
  return tree_map(lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments)

def update_moment_per_elem_norm(updates, moments, decay, order):
  def orderth_norm(g):
    if jnp.isrealobj(g):
      return g ** order
    else:
      half_order = order / 2
      # JAX generates different HLO for int and float `order`
      if half_order.is_integer():
        half_order = int(half_order)
      return jnp.square(g) ** half_order

  return tree_map(lambda g, t: (1 - decay) * orderth_norm(g) + decay * t, updates, moments)

def bias_correction(moment, decay, count):
  bias_correction_ = 1 - decay**count

  # Perform division in the original precision.
  return tree_map(lambda t: t / bias_correction_.astype(t.dtype), moment)

b1 = 0.9
b2 = 0.999
eps = 1e-8
learning_rate = 1e-2

class Adam(NamedTuple):
    b1: Array = jnp.asarray(0.9)
    b2: Array = jnp.asarray(0.999)

    def init(self, parameters: Weights) -> AdamState:
        return AdamState(mu=tree_map(lambda t: jnp.zeros_like(t), parameters),
                         nu=tree_map(lambda t: jnp.zeros_like(t), parameters),
                         count=jnp.zeros([], jnp.int32))

    def update(self,
               gradient: Weights,
               state: AdamState,
               parameters: Weights | None) -> tuple[Weights, AdamState]:
        mu = update_moment(gradient, state.mu, self.b1, 1)
        nu = update_moment_per_elem_norm(gradient, state.nu, self.b2, 2)
        count_inc = state.count + jnp.array(1, dtype=jnp.int32)
        mu_hat = bias_correction(mu, self.b1, count_inc)
        nu_hat = bias_correction(nu, self.b2, count_inc)
        gradient = tree_map(
            lambda m, v: m / (jnp.sqrt(v) + eps), mu_hat, nu_hat)
        gradient = tree_map(lambda m: m * -learning_rate, gradient)
        return gradient, AdamState(count=count_inc, mu=mu, nu=nu)

class SolutionState(NamedTuple):
    gradient_state: AdamState
    weights: Weights

def cli() -> None:
    gradient_transformation = Adam()
    weights = Weights(b1=jnp.array([0., 0., 0.], dtype=jnp.float32),
                      w1=jnp.array([[-0.667605, 0.3261746, -0.0785462]], dtype=jnp.float32),
                      b2=jnp.array([0.], dtype=jnp.float32),
                      w2=jnp.array([[0.464014], [-0.435685], [0.776788]], dtype=jnp.float32))
    gradient_state = AdamState(
        count=jnp.asarray(5),
        mu=Weights(b1=jnp.asarray([-15.569108, -8.185916, -18.872583]),
                   w1=jnp.asarray([[-6488.655, -5813.5786, -11111.309]]),
                   b2=jnp.asarray([-16.122942]),
                   w2=jnp.asarray([[-5100.7495], [-6862.2837], [-8967.359]])),
        nu=Weights(b1=jnp.asarray([7.211683, 1.9927658, 10.598419]),
                   w1=jnp.asarray([[1289447., 1035597.7, 3784737.8]]),
                   b2=jnp.asarray([7.749687]),
                   w2=jnp.asarray([[797202.94], [1442427.9], [2465843.5]])))

    state = SolutionState(gradient_state, weights)

    dataset = [2681.0000, 6406.0000, 2098.0000, 5384.0000, 5765.0000, 2273.0000] * 10
    for i, observation in enumerate(dataset):
        observation = jnp.asarray(observation)
        print(f"Iteration {i}")
        state = train_one_episode(observation, state, gradient_transformation)
    print(state)

@jit
def train_one_episode(observation: Array,
                      state: SolutionState,
                      gradient_transformation: Adam,
                      ) -> SolutionState:
    observations = jnp.reshape(observation, (1, 1))
    weights_bar, observation = _v_infer_gradient_and_value(observations, state.weights)
    new_weights_bar, new_gradient_state = gradient_transformation.update(
        weights_bar, state.gradient_state, state.weights)
    new_weights = tree_map(jnp.add, state.weights, new_weights_bar)
    return SolutionState(new_gradient_state, new_weights)

def _infer(observation: Array, weights: Weights) -> tuple[Array, Array]:
    seeker_loss = internal_infer_co(observation, weights)
    return seeker_loss, observation

def _infer_gradient_and_value(observation: Array, weights: Weights) -> tuple[Weights, Array]:
    bound_infer = partial(_infer, observation)
    f: Callable[[Array], tuple[Array, Array]] = grad(bound_infer, has_aux=True)
    return f(weights)

def _v_infer_gradient_and_value(observations: Array, weights: Weights) -> tuple[Weights, Array]:
    f = vmap(_infer_gradient_and_value, in_axes=(0, None), out_axes=(0, 0))
    weights_bars, infer_outputs = f(observations, weights)
    weights_bar = tree_map(partial(jnp.mean, axis=0), weights_bars)
    return weights_bar, infer_outputs

def odd_power(base: Array, exponent: Array) -> Array:
    return jnp.copysign(jnp.abs(base) ** exponent, base)

def energy(natural_explanation: Array, observation: Array, weights: Weights) -> Array:
    p = observation
    q = (dot(softplus(dot(natural_explanation, softplus(weights.w1)) + weights.b1),
             softplus(weights.w2))
         + weights.b2 + 1e-6 * odd_power(natural_explanation, jnp.array(3.0)))
    return dot(p - q, p) + 0.5 * jnp.sum(jnp.square(q)) - 0.5 * jnp.sum(jnp.square(p))

def internal_infer(observation: Array, weights: Weights) -> Array:
    minimizer = GradientDescent(energy, has_aux=False, maxiter=250, tol=0.001, acceleration=False)
    minimizer_result = minimizer.run(jnp.zeros(1), observation=observation, weights=weights)
    return jnp.sum(jnp.square(minimizer_result.params)) * 1e-1

_Weight_VJP = Callable[[Array], tuple[Array]]

@custom_vjp
def internal_infer_co(observation: Array, weights: Weights) -> Array:
    return internal_infer(observation, weights)

def internal_infer_co_fwd(observation: Array, weights: Weights) -> tuple[Array, _Weight_VJP]:
    return vjp(partial(internal_infer, observation), weights)

def internal_infer_co_bwd(weight_vjp: _Weight_VJP, _: Array) -> tuple[None, Array]:
    weights_bar, = weight_vjp(jnp.ones(()))
    return None, weights_bar

internal_infer_co.defvjp(internal_infer_co_fwd, internal_infer_co_bwd)

if __name__ == "__main__":
    cli()

mblondel commented 1 year ago

Thanks! I confirm that it's hanging at iteration 5 on my machine too, but I'm not sure why yet. In the meantime, you can use LBFGS(energy, has_aux=False, maxiter=250, tol=0.001) as a replacement, which seems to work properly.
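Applied to the MWE above, that swap would look roughly like this (only internal_infer changes; LBFGS is left with its default line search):

from jaxopt import LBFGS

def internal_infer(observation: Array, weights: Weights) -> Array:
    # Same as the MWE, but with LBFGS as a drop-in replacement for GradientDescent.
    minimizer = LBFGS(energy, has_aux=False, maxiter=250, tol=0.001)
    minimizer_result = minimizer.run(jnp.zeros(1), observation=observation, weights=weights)
    return jnp.sum(jnp.square(minimizer_result.params)) * 1e-1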

mblondel commented 1 year ago

If I use this

import jax  # needed for jax.debug.print; the MWE above only uses `from jax import ...`

def internal_infer(observation: Array, weights: Weights) -> Array:
    jax.debug.print("Call to GD")
    minimizer = GradientDescent(energy, has_aux=False, maxiter=250, tol=0.001,
                                acceleration=False)
    minimizer_result = minimizer.run(jnp.zeros(1), observation=observation, weights=weights)
    params = minimizer_result.params
    jax.debug.print("Done with GD")
    jax.debug.print("params: {}", params)
    return jnp.sum(jnp.square(params)) * 1e-1

I get

$ python gd_bug.py
Iteration 0
Call to GD
Done with GD
params: [1020.8626]
Iteration 1
Call to GD
Done with GD
params: [1571.4808]
Iteration 2
Call to GD
Done with GD
params: [875.84705]
Iteration 3
Call to GD
Done with GD
params: [1442.7856]
Iteration 4
Call to GD
Done with GD
params: [1485.6477]
Iteration 5
Call to GD
Done with GD
[hanging]

So I don't know what's going on.

If I remove @jit on train_one_episode, the bug goes away.
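For reference, a quick way to try that without editing the decorator is JAX's disable_jit context manager, e.g. wrapping the MWE's entry point:

import jax

# Disables jit for everything called inside the context,
# including the @jit-decorated train_one_episode.
with jax.disable_jit():
    cli()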

NeilGirdhar commented 1 year ago

So I don't know what's going on.

If I remove @jit on train_one_episode, the bug goes away.

I know, it's very strange! Does it seem like a Jax bug to you?

mblondel commented 1 year ago

What I find strange is that the call to GradientDescent.run finishes, but accessing the result it returns makes the program hang.