google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Port schedule_free optimizer to optax. Original pytorch repo: https://github.com/facebookresearch/schedule_free #911

Closed copybara-service[bot] closed 5 months ago

copybara-service[bot] commented 6 months ago

Port schedule_free optimizer to optax. Original pytorch repo: https://github.com/facebookresearch/schedule_free

Also add warmup_constant_schedule, which we recommend using with schedule_free
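A minimal usage sketch of the new schedule (keyword names here are assumed to mirror optax's existing warmup schedules; the final API may differ):

```python
import optax

# Assumed signature: linear warmup from init_value to peak_value over
# warmup_steps, then held constant at peak_value.
schedule = optax.warmup_constant_schedule(
    init_value=0.0, peak_value=1e-3, warmup_steps=1_000)

schedule(0)       # 0.0 at the start of warmup
schedule(500)     # ~5e-4 halfway through warmup
schedule(10_000)  # 1e-3, held constant after warmup
```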

ameya98 commented 6 months ago

Nice! Thanks for this! One thing that should be mentioned in the docs is that the loss should be evaluated at opt_state.x, not at the params. (Not sure how much this matters in practice.) Otherwise, this looks good to me.
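Something like this, assuming the state exposes the averaged iterate as a field named x (the exact field name, or a helper for extracting it, may differ in the final API):

```python
import jax
import optax

def train_step(params, opt_state, batch, loss_fn, optimizer):
  # Gradients are taken at `params` (the "y" iterate), as usual.
  grads = jax.grad(loss_fn)(params, batch)
  updates, opt_state = optimizer.update(grads, opt_state, params)
  return optax.apply_updates(params, updates), opt_state

def eval_step(params, opt_state, batch, loss_fn):
  # Evaluate at the averaged iterate stored in the optimizer state
  # (field name `x` assumed from the comment above), not at `params`.
  del params  # The raw training params are *not* what should be evaluated.
  return loss_fn(opt_state.x, batch)
```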

adam-hartshorne commented 5 months ago

As it currently stands, I don't think this will work seamlessly with equinox (https://github.com/patrick-kidger/equinox), which is widely used amongst the JAX community.

vroulet commented 5 months ago

> As it currently stands, I don't think this will work seamlessly with equinox (https://github.com/patrick-kidger/equinox), which is widely used amongst the JAX community.

@adam-hartshorne Can you explain why this does not work and, perhaps for future reference, explain what the key criteria are for integrating optimizers with equinox? Thank you!

adam-hartshorne commented 5 months ago

Equinox and the new experimental version of Flax* are both libraries that try to enable a more Pythonic, class-based approach (rather than JAX's purely functional / PyTree paradigm). You are thus encouraged to define classes that are instantiated as a "model" object. They ultimately attempt to return everything as a normal PyTree, but this requires some additional handling of things like static variables and PyTree traversal, so you end up with the split / combine / filter design pattern. To ease this, they also provide special methods for updating parameters during optimisation, which handle all of this.

Here is a simple example https://github.com/patrick-kidger/equinox/blob/main/examples/train_rnn.ipynb

As you can see, it requires the use of decorators (@eqx.filter_value_and_grad, @eqx.filter_jit), and the updates to the params are applied using eqx.apply_updates.
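For reference, a condensed sketch of that pattern (the toy model and data here are placeholders, but the eqx.* calls are the standard equinox API):

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import optax

class Linear(eqx.Module):
  # A toy eqx.Module; real models also carry non-array (static) leaves.
  weight: jax.Array
  bias: jax.Array

  def __call__(self, x):
    return self.weight @ x + self.bias

@eqx.filter_value_and_grad
def compute_loss(model, x, y):
  pred = jax.vmap(model)(x)
  return jnp.mean((pred - y) ** 2)

@eqx.filter_jit
def train_step(model, opt_state, x, y):
  loss, grads = compute_loss(model, x, y)
  # Only the array leaves of the model are handed to optax as "params".
  updates, opt_state = optimizer.update(
      grads, opt_state, eqx.filter(model, eqx.is_array))
  model = eqx.apply_updates(model, updates)
  return model, opt_state, loss

model = Linear(weight=jnp.ones((2, 3)), bias=jnp.zeros((2,)))
optimizer = optax.adamw(1e-3)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
x, y = jnp.ones((8, 3)), jnp.ones((8, 2))
model, opt_state, loss = train_step(model, opt_state, x, y)
```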

Now, looking at your code, the update_fn hard-codes the use of optax_update and jax.tree_util.tree_map on the parameters. I don't believe this will be compatible with all models built by inheriting from eqx.Module.

*There are a number of other attempts at this.

vroulet commented 5 months ago

Thanks for the summary @adam-hartshorne. I'm not sure I understand though: (i) this optimizer still returns updates that are added to the parameters with the library of your choice. All operations are done on pytrees of the same form as the updates (grads), which equinox handles normally (after all, equinox uses optax too, so it also works on updates/params that are the usual pytrees without functions). No "model" is given here, nor any operation that would call the model; see the sketch below. (ii) The issue here would be in the meaning of params in the update function (is it the model, or the params of the model?). That issue, if it exists, is not new: numerous optimizers take the optional params argument, lookahead for example.
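For concreteness, a minimal sketch of the plain-pytree contract I mean, with optax.sgd standing in for the schedule_free-wrapped optimizer:

```python
import jax.numpy as jnp
import optax

# The wrapped optimizer is still a plain optax.GradientTransformation:
# it consumes and produces ordinary pytrees of arrays.
params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
grads = {"w": jnp.full((3,), 0.1), "b": jnp.asarray(0.5)}

opt = optax.sgd(1e-2)  # stand-in for the schedule_free-wrapped optimizer
state = opt.init(params)

# `params` is the pytree of parameters, exactly as for other transformations
# that take the optional `params` argument (lookahead, weight decay, ...).
updates, state = opt.update(grads, state, params)
new_params = optax.apply_updates(params, updates)
```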

That said:

evanatyourservice commented 5 months ago

FWIW, I've run into problems with mu_dtype=bf16, at least with small initializations for transformers (0.01 std normal). I haven't tried keeping only x or only z in bfloat16, though; maybe keeping just one of them in bfloat16 still works, in which case it might make sense to pass in a separate dtype arg for each. Also, should x and z be cast back to state_dtype before creating next_state, and should state_dtype be canonicalized near the top? (A sketch of what I mean is at the end of this comment.)

Edit: To clarify the problems I run into with the bf16 state: it seemingly stalls training altogether, progressing only very slowly.

Edit 2: It seems my issue comes from something other than the dtype; I will have to do more testing. This works differently from ameya98's implementation, but I haven't taken the time to find the differences in the code or to compare against fb research's repo. It does seem very sensitive to dtype, though, vs. regular momentum, which most have become accustomed to keeping in bf16.
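To illustrate the casting/canonicalization suggestion, something along these lines (ScheduleFreeState, next_x and next_z are stand-in names, not necessarily the PR's exact identifiers):

```python
import jax
import jax.numpy as jnp

# Canonicalize the state dtype once, near the top of the constructor.
state_dtype = jax.dtypes.canonicalize_dtype(jnp.bfloat16)

def cast_tree(tree, dtype):
  # Cast every array leaf of a pytree to `dtype` (no-op when dtype is None).
  if dtype is None:
    return tree
  return jax.tree_util.tree_map(lambda t: t.astype(dtype), tree)

# ...then inside update_fn, after computing next_x / next_z in float32:
# next_state = ScheduleFreeState(
#     x=cast_tree(next_x, state_dtype),
#     z=cast_tree(next_z, state_dtype),
#     ...,
# )
```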

ameya98 commented 5 months ago

Any chance this could be merged soon? Wanted to experiment with it!

vroulet commented 5 months ago

This is almost good to go internally. I think we are reaching the last round of reviews, and it should be merged soon (probably early next week).

ameya98 commented 5 months ago

Hi, any update on this?

vroulet commented 5 months ago

I just pinged the author. Sorry for the delay.