Closed timothyas closed 1 month ago
Hello @timothyas,
Thanks for reaching out! You can use `optax.tree_utils.tree_get` to access the learning rate. See
Beautiful, thanks @vroulet! That's exactly what I needed. I'm running optax version 0.2.1, which does not appear to have `optax.tree_utils.tree_get`, but for this case I can just grab the second optimizer state.
I feel like it would be very helpful to mention something about this in this section of the docs. For example, after what's already there, something like what's below. If you agree, I'd be happy to add it as a simple PR.
In this example, hyperparameters could be monitored or modified in either of the gradient transformations `optax.adamw` or `optax.clip` with `optax.inject_hyperparams`. For example, to access the learning rate, the code would be modified as follows:

```python
optimizer = optax.chain(
    optax.clip(1.0),
    optax.inject_hyperparams(optax.adamw)(learning_rate=schedule),
)
```
However, because of the `optax.chain` call, `opt_state` is now a tuple, where the first element refers to the `optax.clip` transformation and the second refers to the `optax.adamw` transformation. So the learning rate could be diagnosed with `optax.tree_utils.tree_get` as follows (for instance in the `fit` function):

```python
lr = optax.tree_utils.tree_get(opt_state, "learning_rate")
```

or alternatively it can be accessed manually via the second element of `opt_state`:

```python
lr = opt_state[1].hyperparams["learning_rate"]
```
First off, thanks to the optax crew for this package! I am relatively new to JAX and Optax, and I have a bit of a rookie question regarding how to combine various optimizer/gradient transformation options. If you think it's useful (and I didn't just somehow miss this in the documentation), I'd be happy to contribute the resulting code ... but let's just see I guess.
I would like to create an optimizer that does the following:

- `optax.warmup_cosine_decay_schedule`
- `optax.inject_hyperparams` to monitor the learning rate (at least)

**The code I would expect (hope?) to work, but does not**
I've tried several variants of this, but reached the conclusion that I just don't understand what's going on.
**The brute force approach that works**
But it seems weird to keep a separate `optimizer` and `clipper`. Any help here would be greatly appreciated. Thanks in advance!