google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Avoid re-compilation and allow non-float32 dtypes in the zoom_linesearch #1108

Closed mirkobunse closed 1 month ago

mirkobunse commented 1 month ago

Problems

I experienced two problems when using the zoom_linesearch within a JIT-compiled training step:

1. The training step is compiled twice, because some dtypes of the linesearch's state change after the first iteration.
2. The compilation fails for dtypes other than float32.

Below, I give a minimal example reproducing each problem and a work-around for the first one. The second problem cannot be worked around from the outside; hence, I propose this PR to ultimately solve both problems.

Minimal working example for the re-compilation issue

I sample some random data and create an LBFGS optimizer with its default zoom linesearch. The training step is JIT-compiled, and a print statement tells us whenever a compilation is triggered, because the statement only executes during tracing and is silent in the compiled variant of the function.

import jax.numpy as jnp
import jax
import optax
from time import time

key = jax.random.PRNGKey(42) # similar to the getting started guide
x = jax.random.normal(key, (16, 2))
y = jnp.sum(x * 0.5, axis=-1)
optimizer = optax.lbfgs()

@jax.jit
def training_step(params, state):
  print("I'm being compiled for", state[-1].info)
  def loss_fn(params):
    y_pred = (params @ x.T).T # a linear model
    loss = jnp.mean(optax.l2_loss(y_pred, y))
    return loss
  loss, grad = jax.value_and_grad(loss_fn)(params)
  updates, state = optimizer.update(
    grad,
    state,
    params,
    value = loss,
    grad = grad,
    value_fn = loss_fn,
  )
  params = optax.apply_updates(params, updates)
  return params, state

params = jnp.array([0.0, 0.0])
state = optimizer.init(params)

t0 = time()
for _ in range(10):
  params, state = training_step(params, state)
print(f"Took {time() - t0} sec")

The output of this script is shown below. We see that the function is compiled twice and that, between iterations, the dtypes of the linesearch info have changed from weak types to strong types.

I'm being compiled for ZoomLinesearchInfo(num_linesearch_steps=Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, decrease_error=Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, curvature_error=Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
I'm being compiled for ZoomLinesearchInfo(num_linesearch_steps=Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>, decrease_error=Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, curvature_error=Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
Took 0.7838089466094971 sec
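
For context: jax.jit treats a weakly typed value and a strongly typed value of the same dtype as different abstract values, so each of them triggers its own trace. The following minimal sketch (mine, not part of the original report) shows this behavior in isolation:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  print("compiling for", x.aval)  # only runs while tracing
  return x + 1

f(1.0)                                   # float32, weak_type=True
f(jnp.asarray(1.0, dtype=jnp.float32))   # float32, weak_type=False -> re-traced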

Workaround for the re-compilation issue

We can change the initial state to have strong dtypes right from the beginning:

state = optimizer.init(params) # as before

state = ( # replace occurrences of weak_type=True with strongly typed arrays.
  *state[:-1], # all but the last item of this state are good
  optax._src.linesearch.ScaleByZoomLinesearchState(
    learning_rate = state[-1].learning_rate, # will not become strong -> can stay weak
    value = jnp.asarray(state[-1].value, dtype=state[-1].value.dtype), # infer dtype
    grad = state[-1].grad, # is strong already
    info = optax._src.linesearch.ZoomLinesearchInfo(
        num_linesearch_steps = jnp.asarray(
          state[-1].info.num_linesearch_steps,
          dtype = jnp.int32 # the original dtype is not the final one -> fix beforehand
        ),
        decrease_error = jnp.asarray(
          state[-1].info.decrease_error,
          dtype = jnp.float32
        ),
        curvature_error = jnp.asarray(
          state[-1].info.curvature_error,
          dtype = jnp.float32
        ),
    ),
  )
)

As desired, the training step is now compiled only once.

I'm being compiled for ZoomLinesearchInfo(num_linesearch_steps=Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>, decrease_error=Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, curvature_error=Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
Took 0.4254724979400635 sec
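
A more generic variant of this workaround (my own sketch, not part of the PR) is to re-materialize every array leaf of the state with an explicit dtype, which strips the weak_type flag throughout. Note that this only removes weakness; if a leaf's dtype differs from the dtype it has after the first update (as noted above for num_linesearch_steps), it still has to be cast to the final dtype by hand.

import jax
import jax.numpy as jnp

def strengthen(tree):
  # Passing an explicit dtype to jnp.asarray drops the weak_type flag
  # while keeping the shape and dtype of each leaf unchanged.
  return jax.tree_util.tree_map(
      lambda leaf: jnp.asarray(leaf, dtype=jnp.asarray(leaf).dtype), tree)

state = strengthen(optimizer.init(params))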

Minimal working example for non-default dtypes

In the above example, we make two slight changes,

- x = jax.random.normal(key, (16, 2))
+ x = jax.random.normal(key, (16, 2), dtype=jnp.bfloat16)

and

- params = jnp.array([0.0, 0.0])
+ params = jnp.array([0.0, 0.0], dtype=jnp.bfloat16)

Unfortunately, these changes break the compilation entirely. Since the error, shown below, arises in the internals of the zoom linesearch, I do not see an easy work-around for this problem.

...

  File "/home/bunse/Repos/optax/optax/_src/linesearch.py", line 1167, in step_fn
    new_state = jax.lax.cond(
                ^^^^^^^^^^^^^
TypeError: true_fun and false_fun output must have identical types, got
ZoomLinesearchState(count='ShapedArray(int32[])', params='ShapedArray(bfloat16[2])', updates='ShapedArray(bfloat16[2])', stepsize_guess='ShapedArray(float32[], weak_type=True)', stepsize='ShapedArray(float32[])', value='ShapedArray(bfloat16[])', grad='ShapedArray(bfloat16[2])', slope='ShapedArray(bfloat16[])', value_init='ShapedArray(bfloat16[])', slope_init='ShapedArray(bfloat16[])', decrease_error='DIFFERENT ShapedArray(float32[]) vs. ShapedArray(bfloat16[])', curvature_error='ShapedArray(bfloat16[])', error='DIFFERENT ShapedArray(float32[]) vs. ShapedArray(bfloat16[])', interval_found='ShapedArray(bool[])', done='ShapedArray(bool[])', failed='ShapedArray(bool[])', low='ShapedArray(float32[])', value_low='ShapedArray(bfloat16[])', slope_low='ShapedArray(bfloat16[])', high='ShapedArray(float32[])', value_high='ShapedArray(bfloat16[])', slope_high='ShapedArray(bfloat16[])', cubic_ref='ShapedArray(float32[], weak_type=True)', value_cubic_ref='ShapedArray(bfloat16[])', safe_stepsize='ShapedArray(float32[])', safe_value='ShapedArray(bfloat16[])', safe_grad='ShapedArray(bfloat16[2])')
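
The underlying mechanism is that both branches of jax.lax.cond must return pytrees with identical shapes and dtypes; per the error above, fields such as decrease_error end up float32 in one branch but bfloat16 in the other. A minimal sketch of this failure mode (mine, not taken from the linesearch code) is:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  # one branch returns float32, the other bfloat16 -> "identical types" TypeError
  return jax.lax.cond(
      x > 0,
      lambda: jnp.asarray(0.0, dtype=jnp.float32),
      lambda: x.astype(jnp.bfloat16),
  )

f(jnp.asarray(1.0, dtype=jnp.bfloat16))  # raises the TypeError shown above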

Solution: this PR

I suggest the following improvements: the linesearch state is initialized with strongly typed arrays, so that the training step is compiled only once, and the user can specify the dtype in which the linesearch value is handled, so that dtypes other than float32 compile.

The specification of a value dtype works as follows:

- optimizer = optax.lbfgs()
+ optimizer = optax.lbfgs(
+   linesearch = optax.scale_by_zoom_linesearch(15, value_dtype=jnp.bfloat16)
+ )
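
For completeness, this is how the bfloat16 example from above would be set up with the proposed argument; value_dtype is the argument added by this PR and is not available in released optax versions:

import jax.numpy as jnp
import optax

optimizer = optax.lbfgs(
  linesearch=optax.scale_by_zoom_linesearch(
    max_linesearch_steps=15,
    value_dtype=jnp.bfloat16,  # proposed by this PR
  ),
)
params = jnp.array([0.0, 0.0], dtype=jnp.bfloat16)
state = optimizer.init(params)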
google-cla[bot] commented 1 month ago

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up-to-date status, view the checks section at the bottom of the pull request.

vroulet commented 1 month ago

Hello @mirkobunse ,

Thanks for the fix! There is a larger issue in optax: any optimizer with mixed precision (different precisions for gradients and parameters) may compile twice. A fix would probably be to change the signature of the init functions rather than adding arguments to the definition of the optimizer. We will discuss that internally soon and I'll keep you posted.

If this is blocking your research, you may also consider optimistix, which carefully took care of some recompilation issues.