Open NeilGirdhar opened 1 year ago
@zaccharieramzi Since you had a recent commit about LBFGS, would you kindly be able to let me know if my error could be a failed line search in LBFGS?
Hi @NeilGirdhar : looking at your linked MWE, I am not sure where in your code you use LBFGS.
To know whether a failed line search is happening when using LBFGS, you can always check the failed_linesearch attribute of the state, which should tell you whether line search has failed or not.
You can also just enable stop_if_linesearch_fails when instantiating the solver.
However, I doubt this is a line search error: what I corrected was not something that created deadlocks, just longer optimization runs. How many iterations do you need to reach a freeze? Maybe a stupid question, but how do you know it's a freeze and not just the algorithm being super slow?
I saw in your older issue that disabling jit solves the problem so I couldn't really give you smart debugging advice unfortunately, I am really new to jax so it's difficult for me to help in that regard.
Also pinging @mblondel who might know more on this since he developed LBFGS
Hi @NeilGirdhar : looking at your linked MWE, I am not sure where in your code you use LBFGS.
Oh, my mistake, I switched it to GradientDescent for the time being.
stop_if_linesearch_fails
Thanks, I'll keep this in mind when I switch back to LBFGS.
How many iterations do you need to reach a freeze? Maybe a stupid question, but how do you know it's a freeze and not just the algorithm being super slow?
Five iterations of Jax code. I'm not sure how many iterations of the solver, but maxiter is set to 250. I have let it run for a very long time. The reason I don't think it's just slow is because that would mean that JaxOpt is orders of magnitude slower than my naively implemented gradient descent, and it wouldn't explain the problem disappearing when switching to 64 bits.
Something very strange is happening, either in Jax's XLA compiler, or in JaxOpt's loops.
I saw in your older issue that disabling jit solves the problem so I couldn't really give you smart debugging advice unfortunately, I am really new to jax so it's difficult for me to help in that regard.
No worries, I really appreciate the speedy reply. I've been struggling for a month now!
Ok maybe you can try searching for the minimal setting of maxiter that makes this bug.
Also it could still be that the algorithm is super slow in 32bits because it doesn't reach the tolerance for some floating point error reason, and does in 64bits (but that's just a wild guess).
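To get a feel for how coarse 32-bit floats are at the magnitudes that appear in the MWE (the Adam state there has entries above 1e6), the gap between neighbouring representable numbers can be printed directly. This is only a sketch to support the guess above; it uses NumPy so it does not depend on whether JAX's 64-bit mode is enabled:

```python
import numpy as np

# Gap between 1e6 and the next representable number in each precision.
gap32 = np.spacing(np.float32(1e6))
gap64 = np.spacing(np.float64(1e6))

print("float32 gap near 1e6:", gap32)
print("float64 gap near 1e6:", gap64)
```

In float32, differences smaller than the printed gap cannot be represented at that magnitude, which is one way a decrease test or a stopping criterion could behave differently between the two precisions.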
But does your code work when using jaxopt's GradientDescent? (also pay attention to the fact that jaxopt's GradientDescent is by default accelerated)
Ok maybe you can try searching for the minimal setting of maxiter that makes this bug.
I got it down to maxiter=11.
Also it could still be that the algorithm is super slow in 32bits because it doesn't reach the tolerance for some floating point error reason, and does in 64bits (but that's just a wild guess).
Interesting idea, but my tolerance is 0.001. Is that low enough to cause problems?
But does your code work when using jaxopt's GradientDescent? (also pay attention to the fact that jaxopt's GradientDescent is by default accelerated)
It doesn't work with GradientDescent.
(also pay attention to the fact that jaxopt's GradientDescent is by default accelerated)
Okay, but I'm running on the CPU only. Unless you mean jitted? If so, that's related to the bug since turning off the JIT prevents the bug.
no no with this tolerance it should be fine.
Haaa ok it doesn't work with GradientDescent, I thought it did. Have you tried removing all "options" from the solver:
- taking a constant stepsize or the schedule of your choice (or maybe just reducing maxls, the maximum number of line search iterations)
- setting acceleration=False to get rid of the FISTA step
We might pin down the issue if you try with either of these (or both) turned off.
setting acceleration=False to get rid of the FISTA step
I've already set it to false.
taking a constant stepsize or the schedule of your choice (or maybe just reducing maxls the maximum number of line search iterations)
Lowering maxls from its default of 15 down to 5 does make the program get farther but it eventually freezes. At 2 it succeeds. At 4, it produces a bunch of nans and doesn't freeze.
Setting a constant step size of 1e-3 allows the program to complete normally.
I guess one thing that would definitely cause this is if the minimized function f returning nan made JaxOpt's gradient descent loop forever. I would consider that a bug in JaxOpt.
It shouldn't because every while_loop has maxiter, but indeed you could test this in a very simple setup with a function that returns nan randomly.
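Such a test could look roughly like the sketch below. It is not JaxOpt's loop: `bounded_descent` and `nan_objective` are hypothetical names, and the loop is a plain lax.while_loop guarded by both an iteration cap and a tolerance, which is the pattern described above. For simplicity the objective returns nan always rather than randomly.

```python
import jax
import jax.numpy as jnp
from jax import lax


def bounded_descent(fun, x0, stepsize=0.1, maxiter=250, tol=1e-3):
    """Plain gradient descent in a while_loop guarded by maxiter and tol."""
    grad_fun = jax.grad(fun)

    def cond(carry):
        _, error, iter_num = carry
        # With a NaN error, `error > tol` is False, so the loop exits.
        return (iter_num < maxiter) & (error > tol)

    def body(carry):
        x, _, iter_num = carry
        g = grad_fun(x)
        return x - stepsize * g, jnp.linalg.norm(g), iter_num + 1

    init = (x0, jnp.asarray(jnp.inf, dtype=x0.dtype), jnp.asarray(0))
    return lax.while_loop(cond, body, init)


def nan_objective(x):
    # Hypothetical objective whose value and gradient are always NaN.
    return jnp.sum(x) * jnp.nan


x, error, iter_num = bounded_descent(nan_objective, jnp.zeros(1))
print("iterations:", int(iter_num), "error:", float(error))
```

Here the NaN makes the tolerance comparison False, so the loop stops after a single step. A guard written the other way around (continue unless error <= tol) would keep going on NaN, but would still be cut off by maxiter.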
Unfortunately since this seems related to the very specific fista_linesearch implemented here, I am afraid I am not able to help much further at this stage. I guess @mblondel will have much more to say (maybe after the ICML deadline).
@zaccharieramzi Okay, thanks so much for your kind help!
@mblondel Would you happen to have some time to look at this blocking bug for me?
Do you have a reproducing example in a single script that I can execute?
@mblondel It is in a single script. (I linked it on the issue.) However, if you do the poetry installation, then it'll guarantee that you have the same environment too.
Maybe I missed something but I don't see any main function in this file.
@mblondel Oh, I'm sorry, cli is the entry point. You can simply append cli() to the file to run it.
Did you try running your program in float64 precision?
@mblondel Yes, that works.
I appreciate your taking a look. I don't think running in 64-bit precision is a reasonable workaround since I don't think Jax should ever lock up like this. Something has gone seriously wrong either with Jax's compilation or in JaxOpt.
The fact that it works in float64 suggests that there might be a way to work around the stability issue by writing the objective function differently. I'm not sure if the bug comes from JAXopt, since as you said, our loops all have a maximum number of iterations.
The fact that it works in float64 suggests that there might be a way to work around the stability issue by writing the objective function differently.
How does a "stability issue" cause a deadlock?
I'm not sure if the bug comes from JAXopt,
Right, I'm not sure either. But the alternative is that it comes from Jax, which means some kind of compiler bug. Were you able to reproduce it?
@mblondel Sorry to keep bugging you on this. I really want to use JaxOpt, but I'm blocked on this bug. Do you think this is likely to be a problem with Jax? Were you able to reproduce this bug?
I can give it a try but I need a single script (as short as possible) that I can run and without additional dependencies.
@mblondel Okay thanks, I'll start putting that together for you.
No dependencies except Jax and JaxOpt.
from __future__ import annotations

from collections.abc import Callable
from functools import partial
from typing import NamedTuple

import jax.numpy as jnp
from jax import Array, custom_vjp, grad, jit, vjp, vmap
from jax.lax import dot
from jax.nn import softplus
from jax.tree_util import tree_map
from jaxopt import GradientDescent


class Weights(NamedTuple):
    b1: Array
    w1: Array
    b2: Array
    w2: Array


class AdamState(NamedTuple):
    count: Array
    mu: Weights
    nu: Weights


def update_moment(updates, moments, decay, order):
    return tree_map(lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments)


def update_moment_per_elem_norm(updates, moments, decay, order):
    def orderth_norm(g):
        if jnp.isrealobj(g):
            return g ** order
        else:
            half_order = order / 2
            # JAX generates different HLO for int and float `order`
            if half_order.is_integer():
                half_order = int(half_order)
            return jnp.square(g) ** half_order

    return tree_map(lambda g, t: (1 - decay) * orderth_norm(g) + decay * t, updates, moments)


def bias_correction(moment, decay, count):
    bias_correction_ = 1 - decay**count
    # Perform division in the original precision.
    return tree_map(lambda t: t / bias_correction_.astype(t.dtype), moment)


b1 = 0.9
b2 = 0.999
eps = 1e-8
learning_rate = 1e-2


class Adam(NamedTuple):
    b1: Array = jnp.asarray(0.9)
    b2: Array = jnp.asarray(0.999)

    def init(self, parameters: Weights) -> AdamState:
        return AdamState(mu=tree_map(lambda t: jnp.zeros_like(t), parameters),
                         nu=tree_map(lambda t: jnp.zeros_like(t), parameters),
                         count=jnp.zeros([], jnp.int32))

    def update(self,
               gradient: Weights,
               state: AdamState,
               parameters: Weights | None) -> tuple[Weights, AdamState]:
        mu = update_moment(gradient, state.mu, self.b1, 1)
        nu = update_moment_per_elem_norm(gradient, state.nu, self.b2, 2)
        count_inc = state.count + jnp.array(1, dtype=jnp.int32)
        mu_hat = bias_correction(mu, self.b1, count_inc)
        nu_hat = bias_correction(nu, self.b2, count_inc)
        gradient = tree_map(
            lambda m, v: m / (jnp.sqrt(v) + eps), mu_hat, nu_hat)
        gradient = tree_map(lambda m: m * -learning_rate, gradient)
        return gradient, AdamState(count=count_inc, mu=mu, nu=nu)


class SolutionState(NamedTuple):
    gradient_state: AdamState
    weights: Weights


def cli() -> None:
    gradient_transformation = Adam()
    weights = Weights(b1=jnp.array([0., 0., 0.], dtype=jnp.float32),
                      w1=jnp.array([[-0.667605, 0.3261746, -0.0785462]], dtype=jnp.float32),
                      b2=jnp.array([0.], dtype=jnp.float32),
                      w2=jnp.array([[0.464014], [-0.435685], [0.776788]], dtype=jnp.float32))
    gradient_state = AdamState(
        count=jnp.asarray(5),
        mu=Weights(b1=jnp.asarray([-15.569108, -8.185916, -18.872583]),
                   w1=jnp.asarray([[-6488.655, -5813.5786, -11111.309]]),
                   b2=jnp.asarray([-16.122942]),
                   w2=jnp.asarray([[-5100.7495], [-6862.2837], [-8967.359]])),
        nu=Weights(b1=jnp.asarray([7.211683, 1.9927658, 10.598419]),
                   w1=jnp.asarray([[1289447., 1035597.7, 3784737.8]]),
                   b2=jnp.asarray([7.749687]),
                   w2=jnp.asarray([[797202.94], [1442427.9], [2465843.5]])))
    state = SolutionState(gradient_state, weights)
    dataset = [2681.0000, 6406.0000, 2098.0000, 5384.0000, 5765.0000, 2273.0000] * 10
    for i, observation in enumerate(dataset):
        observation = jnp.asarray(observation)
        print(f"Iteration {i}")
        state = train_one_episode(observation, state, gradient_transformation)
    print(state)


@jit
def train_one_episode(observation: Array,
                      state: SolutionState,
                      gradient_transformation: Adam,
                      ) -> SolutionState:
    observations = jnp.reshape(observation, (1, 1))
    weights_bar, observation = _v_infer_gradient_and_value(observations, state.weights)
    new_weights_bar, new_gradient_state = gradient_transformation.update(
        weights_bar, state.gradient_state, state.weights)
    new_weights = tree_map(jnp.add, state.weights, new_weights_bar)
    return SolutionState(new_gradient_state, new_weights)


def _infer(observation: Array, weights: Weights) -> tuple[Array, Array]:
    seeker_loss = internal_infer_co(observation, weights)
    return seeker_loss, observation


def _infer_gradient_and_value(observation: Array, weights: Weights) -> tuple[Weights, Array]:
    bound_infer = partial(_infer, observation)
    f: Callable[[Array], tuple[Array, Array]] = grad(bound_infer, has_aux=True)
    return f(weights)


def _v_infer_gradient_and_value(observations: Array, weights: Weights) -> tuple[Weights, Array]:
    f = vmap(_infer_gradient_and_value, in_axes=(0, None), out_axes=(0, 0))
    weights_bars, infer_outputs = f(observations, weights)
    weights_bar = tree_map(partial(jnp.mean, axis=0), weights_bars)
    return weights_bar, infer_outputs


def odd_power(base: Array, exponent: Array) -> Array:
    return jnp.copysign(jnp.abs(base) ** exponent, base)


def energy(natural_explanation: Array, observation: Array, weights: Weights) -> Array:
    p = observation
    q = (dot(softplus(dot(natural_explanation, softplus(weights.w1)) + weights.b1),
             softplus(weights.w2))
         + weights.b2 + 1e-6 * odd_power(natural_explanation, jnp.array(3.0)))
    return dot(p - q, p) + 0.5 * jnp.sum(jnp.square(q)) - 0.5 * jnp.sum(jnp.square(p))


def internal_infer(observation: Array, weights: Weights) -> Array:
    minimizer = GradientDescent(energy, has_aux=False, maxiter=250, tol=0.001, acceleration=False)
    minimizer_result = minimizer.run(jnp.zeros(1), observation=observation, weights=weights)
    return jnp.sum(jnp.square(minimizer_result.params)) * 1e-1


_Weight_VJP = Callable[[Array], tuple[Array]]


@custom_vjp
def internal_infer_co(observation: Array, weights: Weights) -> Array:
    return internal_infer(observation, weights)


def internal_infer_co_fwd(observation: Array, weights: Weights) -> tuple[Array, _Weight_VJP]:
    return vjp(partial(internal_infer, observation), weights)


def internal_infer_co_bwd(weight_vjp: _Weight_VJP, _: Array) -> tuple[None, Array]:
    weights_bar, = weight_vjp(jnp.ones(()))
    return None, weights_bar


internal_infer_co.defvjp(internal_infer_co_fwd, internal_infer_co_bwd)


if __name__ == "__main__":
    cli()
Thanks! I confirm that it's hanging at iteration 5 on my machine too but I'm not sure why yet. In the meantime, you can use LBFGS(energy, has_aux=False, maxiter=250, tol=0.001) as replacement, which seems to work properly.
If I use this
import jax  # jax.debug.print needs the top-level jax module


def internal_infer(observation: Array, weights: Weights) -> Array:
    jax.debug.print("Call to GD")
    minimizer = GradientDescent(energy, has_aux=False, maxiter=250, tol=0.001,
                                acceleration=False)
    minimizer_result = minimizer.run(jnp.zeros(1), observation=observation, weights=weights)
    params = minimizer_result.params
    jax.debug.print("Done with GD")
    jax.debug.print("params: {}", params)
    return jnp.sum(jnp.square(params)) * 1e-1
I get
$ python gd_bug.py
Iteration 0
Call to GD
Done with GD
params: [1020.8626]
Iteration 1
Call to GD
Done with GD
params: [1571.4808]
Iteration 2
Call to GD
Done with GD
params: [875.84705]
Iteration 3
Call to GD
Done with GD
params: [1442.7856]
Iteration 4
Call to GD
Done with GD
params: [1485.6477]
Iteration 5
Call to GD
Done with GD
[hanging]
So I don't know what's going on.
If I remove @jit on train_one_episode, the bug goes away.
So I don't know what's going on.
If I remove @jit on train_one_episode, the bug goes away.
I know, it's very strange! Does it seem like a Jax bug to you?
What I find strange is that the call to GradientDescent.run finishes but accessing the result returned by it makes the program hang.
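One possibility that may be worth ruling out, offered as a guess rather than a diagnosis: as far as I understand, a jax.debug.print call that takes no arguments has no data dependency on the solver output, so inside jitted code it is not guaranteed to run after the solver has finished. A print that receives the result as an argument is tied to that value. The sketch below shows the pattern on a stand-in computation; `solve_and_report` is a hypothetical name and the sum of squares merely takes the place of the solver call:

```python
import jax
import jax.numpy as jnp


@jax.jit
def solve_and_report(x):
    # Stand-in for the solver call; hypothetical, not the MWE's code.
    result = jnp.sum(jnp.square(x))
    # Passing `result` to the print ties the message to the computed value.
    jax.debug.print("Done, result: {}", result)
    return result


out = solve_and_report(jnp.arange(3.0))
# Block until the computation has actually finished on the device.
out.block_until_ready()
```

If that reading is right, the "Done with GD" line in the log above would not by itself prove that the run completed, whereas the missing "params" line at iteration 5 would be consistent with the solver never returning.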
I have a mysterious deadlock that was introduced when I replaced my gradient descent with JaxOpt's. I reported it as a Jax bug, but could it be a JaxOpt bug? The MWE is about 100 lines, and can be reproduced using these instructions. I did comb through the JaxOpt source code, but it seems that all while loops are guarded by maximum iteration conditions. Is there something I could be missing?