google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

How to Cleanly Specify Optimizer, Schedule, & Gradient Clipping with inject_hyperparams #964

Closed: timothyas closed this issue 1 month ago

timothyas commented 1 month ago

First off, thanks to the optax crew for this package! I am relatively new to JAX and Optax, and I have a bit of a rookie question regarding how to combine various optimizer/gradient transformation options. If you think it's useful (and I didn't just somehow miss this in the documentation), I'd be happy to contribute the resulting code ... but let's just see I guess.

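I would like to create an optimizer that clips gradients by their global norm, applies AdamW with a warmup cosine decay learning rate schedule, and uses optax.inject_hyperparams so that I can monitor the current learning rate.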

The code I would expect (hope?) to work, but does not

I've tried several variants of this, but reached the conclusion that I just don't understand what's going on.

import optax
import jax.numpy as jnp

def custom_optimizer():
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.,
        peak_value=1e-3,
        warmup_steps=10,
        decay_steps=100,
        end_value=0.,
    )
    adam = optax.inject_hyperparams(optax.adamw)(
        learning_rate=schedule,
    )
    return optax.chain(
        optax.clip_by_global_norm(1.),
        adam,
    )

grads = {'w': jnp.full((5, 5), 0.1), 'b': jnp.full((5), 0.1)} # dummy
params = {'w': jnp.full((5, 5), 0.1), 'b': jnp.full((5), 0.1)} # dummy

optimizer = custom_optimizer()
opt_state = optimizer.init(params)

updates, new_opt_state = optimizer.update(grads, opt_state, params)
print(new_opt_state.hyperparams['learning_rate'])

The brute force approach that works

But it seems weird to keep a separate optimizer and clipper.

import optax
import jax.numpy as jnp

def custom_optimizer():
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.,
        peak_value=1e-3,
        warmup_steps=10,
        decay_steps=100,
        end_value=0.,
    )
    adam = optax.inject_hyperparams(optax.adamw)(
        learning_rate=schedule,
    )
    return adam

grads = {'w': jnp.full((5, 5), 0.1), 'b': jnp.full((5), 0.1)} # dummy
params = {'w': jnp.full((5, 5), 0.1), 'b': jnp.full((5), 0.1)} # dummy

optimizer = custom_optimizer()
opt_state = optimizer.init(params)

clipper = optax.clip_by_global_norm(1.)
clip_state = clipper.init(params)

new_grads, clip_state = clipper.update(grads, clip_state)
updates, new_opt_state = optimizer.update(new_grads, opt_state, params)
print(new_opt_state.hyperparams['learning_rate'])

Any help here would be greatly appreciated. Thanks in advance!

vroulet commented 1 month ago

Hello @timothyas,

Thanks for reaching out! You can use optax.tree_utils.tree_get to access the learning rate. See #961.

The issue you have is that the state returned by the chained optimizer is a tuple of two states (the first for the clipping, the second for AdamW). You can then get the learning rate as state[1].hyperparams['learning_rate']. Rather than requiring you to memorize the structure of the chained optimizer, optax.tree_utils.tree_get searches the pytree for the entry corresponding to 'learning_rate'.

timothyas commented 1 month ago

Beautiful, thanks @vroulet! That's exactly what I needed. I'm running optax version 0.2.1, which does not appear to have optax.tree_utils.tree_get, but for this case I can just grab the second optimizer state.

timothyas commented 1 month ago

I feel like it would be very helpful to mention something about this in this section of the docs, for example by adding something like the text below after what's already there. If you agree, I'd be happy to add it as a simple PR.

In this example, the hyperparameters of either gradient transformation, optax.adamw or optax.clip, could be monitored or modified by wrapping it with optax.inject_hyperparams. For example, to access the learning rate, the code would be modified as follows:

optimizer = optax.chain(
  optax.clip(1.0),
  optax.inject_hyperparams(optax.adamw)(learning_rate=schedule),
)

However, because of the optax.chain call, opt_state is now a tuple: the first element is the state of the optax.clip transformation and the second is the state of the optax.adamw transformation. The learning rate can therefore be retrieved with optax.tree_utils.tree_get as follows (for instance, in the fit function):

lr = optax.tree_utils.tree_get(opt_state, "learning_rate")

or, alternatively, it can be accessed manually via the second element of opt_state:

lr = opt_state[1].hyperparams["learning_rate"]