Closed copybara-service[bot] closed 5 months ago
Nice! Thanks for this! One thing that should be mentioned in the docs is that the loss should be evaluated at opt_state.x, not the params. (Not sure how much this matters in practice.) Otherwise, this looks good to me.
As it currently stands I don't think this will work seamlessly with equinox (https://github.com/patrick-kidger/equinox) which is widely used amongst JAX community.
@adam-hartshorne Can you explain why this does not work, and, for future reference, what the key criteria are for integrating optimizers with equinox? Thank you!
Equinox and the new experimental version of Flax* are both libraries that try to enable a more Pythonic, class-based approach (rather than JAX's purely functional / PyTree paradigm). You are thus encouraged to define classes that are instantiated as a "model" object. They ultimately attempt to return everything as a normal PyTree, but this requires some additional handling of things like static variables and traversing PyTrees; hence the split / combine / filter design pattern. To ease this, they also provide special methods for applying updates during optimisation, which handle all of this.
Here is a simple example https://github.com/patrick-kidger/equinox/blob/main/examples/train_rnn.ipynb
As you can see it requires the use of decorators, @eqx.filter_value_and_grad, @eqx.filter_jit and the updates to the params are applied using eqx.apply_updates.
Now, looking at your code, the update_fn hard-codes the use of optax_update and jax.tree_util.tree_map on the parameters. I don't believe this will be compatible with all models that have been built by inheriting from eqx.Module classes.
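For readers unfamiliar with the split / combine / filter pattern being discussed: the issue is that an Equinox module is a pytree whose leaves may include non-arrays (activation functions, static flags), which a blind tree_map over "the model" would choke on. Below is a toy standalone re-implementation of the idea (this is not Equinox's actual code; eqx.partition/eqx.combine are the real APIs):

```python
import jax
import jax.numpy as jnp

_STATIC = object()  # placeholder leaf marking "not an array here"

def partition(tree):
    """Split a pytree into (arrays, static), with placeholder leaves so
    both halves keep the original structure and can be tree_map'ed safely."""
    is_arr = lambda x: isinstance(x, jnp.ndarray)
    arrays = jax.tree_util.tree_map(lambda x: x if is_arr(x) else _STATIC, tree)
    static = jax.tree_util.tree_map(lambda x: _STATIC if is_arr(x) else x, tree)
    return arrays, static

def combine(arrays, static):
    """Merge the two halves back into the original pytree."""
    return jax.tree_util.tree_map(
        lambda a, s: s if a is _STATIC else a, arrays, static)

# A toy "module": array parameters mixed with a non-array leaf.
model = {"w": jnp.ones(2), "b": jnp.zeros(2), "activation": jnp.tanh}
params, static = partition(model)

# An optimizer can now safely tree_map over the array half only...
new_params = jax.tree_util.tree_map(
    lambda p: p if p is _STATIC else p - 0.1, params)

# ...and the caller reassembles the full model afterwards.
model = combine(new_params, static)
```

This is what eqx.filter_value_and_grad and eqx.apply_updates automate: filtering out the non-array leaves before differentiation/updates, then recombining.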
*there are a number of other attempts at this.
Thanks for the summary @adam-hartshorne. I'm not sure I understand though: (i) this optimizer will still return updates that will be added to the parameters with the library of your choice. All operations are done on pytrees of the form of the updates (grads), which are normally handled by e.g. equinox (after all, equinox uses optax too, so it also works on updates/params that are usual pytrees without functions). No "model" is given here, nor any operation that would call the model. (ii) The issue here would be in the definition of params in the update function (is it the model, or the params of the model?). This issue, if it exists, is not new: numerous optimizers have an optional "params" argument, like lookahead for example.
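To make point (i) concrete, here is a toy re-implementation of the contract being described: an optax-style transformation only ever sees pytrees of gradients (and optionally params), never the model. This is a hand-rolled sketch mirroring the GradientTransformation shape, not optax itself:

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

class ScaleState(NamedTuple):
    count: jnp.ndarray  # number of update calls so far

def scale_by_constant(step_size):
    """Toy transformation: init(params) -> state,
    update(grads, state, params) -> (updates, new_state)."""
    def init_fn(params):
        return ScaleState(count=jnp.zeros([], jnp.int32))
    def update_fn(grads, state, params=None):
        # Only pytree-of-arrays operations here; no model is ever called.
        updates = jax.tree_util.tree_map(lambda g: -step_size * g, grads)
        return updates, ScaleState(count=state.count + 1)
    return init_fn, update_fn

def apply_updates(params, updates):
    return jax.tree_util.tree_map(lambda p, u: p + u, params, updates)

params = {"w": jnp.ones(2), "b": jnp.zeros(2)}
init, update = scale_by_constant(0.1)
state = init(params)
grads = {"w": jnp.ones(2), "b": jnp.ones(2)}
updates, state = update(grads, state, params)
params = apply_updates(params, updates)
```

Any library that can hand over such a pytree of gradients (equinox included, after filtering) can consume the returned updates.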
That said:
fwiw, I've run into problems with mu_dtype=bf16, at least with small initializations for transformers (0.01 std normal). I haven't tried keeping only x or only z in bfloat16 though; maybe keeping just one of them in bfloat16 still works, in which case it might make sense to pass separate dtype args for each. Also, should x and z be cast back to state_dtype before creating next_state, and should state_dtype be canonicalized near the top?
Edit: To clarify the problems I ran into with the bf16 state: it seemingly stalls training altogether and progresses only very slowly.
Edit 2: It seems my issue comes from something other than the dtype; I'll have to do more testing. It works differently from ameya98's implementation, but I haven't taken the time to find the differences in the code or against fb research's repo. It does seem very sensitive to dtype, though, versus regular momentum, which most have become accustomed to keeping in bf16.
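A plausible mechanism for the stalling described above (illustrative numbers, not a claim about this PR's actual bug): bfloat16 has only an 8-bit mantissa, so adding a small update to a moderately sized state value rounds straight back to the original value, and the state stops moving:

```python
import jax
import jax.numpy as jnp

x32 = jnp.array(1.0, dtype=jnp.float32)
x16 = jnp.array(1.0, dtype=jnp.bfloat16)
step = 1e-3  # a typical small parameter update

moved_32 = x32 + step  # the update registers in float32
moved_16 = x16 + step  # lost entirely: 1e-3 is below bf16's half-ulp at 1.0

# The canonicalization asked about in the comment would look like this:
canonical = jax.dtypes.canonicalize_dtype(jnp.float64)  # float32 unless x64 is enabled
```

This is why keeping only part of the state in bf16 (or casting back to a canonicalized state_dtype only at storage time, accumulating in f32) can behave very differently from putting everything in bf16.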
Any chance this could be merged soon? Wanted to experiment with it!
This is almost good to go internally. I think we are reaching the last round of reviews, and it should be merged soon (probably early next week).
Hi, any update on this?
I just pinged the author. Sorry for the delay
Port schedule_free optimizer to optax. Original pytorch repo: https://github.com/facebookresearch/schedule_free
Also add warmup_constant_schedule, which we recommend using with schedule_free.
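For reference, a warmup-then-constant schedule ramps the learning rate linearly from zero to a peak value and then holds it there (schedule-free methods replace the decay phase, so only warmup is needed). This is a hedged standalone sketch, not the actual optax implementation:

```python
import jax.numpy as jnp

def warmup_constant_schedule(peak_value, warmup_steps):
    """Linearly ramp from 0 to peak_value over warmup_steps, then hold."""
    def schedule(step):
        frac = jnp.minimum(step / warmup_steps, 1.0)  # clip ramp at 1.0
        return peak_value * frac
    return schedule

sched = warmup_constant_schedule(peak_value=1e-3, warmup_steps=100)
```

After warmup_steps the schedule is flat, since schedule-free optimization handles the averaging that a decay schedule would otherwise provide.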