google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Documentation for cosine decay schedule #905

Open gjhuizing opened 3 months ago

gjhuizing commented 3 months ago

Hello,

The formula in the documentation for cosine_decay_schedule (https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.cosine_decay_schedule) suggests that the learning rate increases again after T steps.
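The mismatch can be illustrated with a small plain-Python sketch. It assumes the documented formula with the default exponent of 1, and the helper names are hypothetical; this is not optax's implementation. Evaluating the formula literally makes the value climb back toward init_value after decay_steps, while clipping t at decay_steps keeps it at init_value * alpha.

```python
import math


def cosine_formula_unclipped(t, init_value, decay_steps, alpha=0.0):
    # Hypothetical helper: the documented formula evaluated literally,
    # with no bound on the step count t.
    cosine = 0.5 * (1.0 + math.cos(math.pi * t / decay_steps))
    return init_value * ((1.0 - alpha) * cosine + alpha)


def cosine_decay_clipped(t, init_value, decay_steps, alpha=0.0):
    # Hypothetical helper: same formula, but t is clipped at decay_steps,
    # so the value stays at init_value * alpha once the decay has finished.
    t = min(t, decay_steps)
    return cosine_formula_unclipped(t, init_value, decay_steps, alpha)


if __name__ == "__main__":
    # Past step 100 the unclipped curve rises again; the clipped one stays flat.
    for step in (0, 50, 100, 150, 200):
        print(
            step,
            cosine_formula_unclipped(step, 1.0, 100),
            cosine_decay_clipped(step, 1.0, 100),
        )
```

In optax itself, the corresponding check would be to build optax.cosine_decay_schedule(init_value=1.0, decay_steps=100) and evaluate it at a step past 100, which gives the final value rather than a rising one.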

A quick look at the code confirms this is not the case (the step count is clipped at T), but it would be good to state this explicitly, as the linear_schedule documentation does.
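One possible way to write the cosine formula with the clipping made explicit is shown below. Here I is init_value, T is decay_steps and α is alpha, and the default exponent of 1 is assumed. This is a sketch of a documentation wording, not the current docstring:

```math
\eta(t) = I\left[(1-\alpha)\,\frac{1}{2}\left(1 + \cos\left(\pi\,\frac{\min(t, T)}{T}\right)\right) + \alpha\right]
```

With this form, the value for every t ≥ T is Iα, so the schedule no longer appears to rise again.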

I'm happy to make a short PR! I could also propose a short formula or pseudocode for functions like piecewise_constant_schedule that do not have one.
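As an example of such pseudocode for piecewise_constant_schedule, here is a hypothetical plain-Python sketch. It assumes that boundaries_and_scales maps a step count b_i to a multiplicative factor s_i, so that the value is I times the product of every s_i whose boundary has been reached. Whether a factor already applies exactly at its boundary step (t >= b_i rather than t > b_i) is an assumption here and should be checked against the implementation.

```python
def piecewise_constant(t, init_value, boundaries_and_scales):
    # Hypothetical sketch, not optax's code: start from init_value and
    # multiply by each scale whose boundary step has been reached.
    value = init_value
    for boundary, scale in sorted(boundaries_and_scales.items()):
        if t >= boundary:  # assumption: the factor applies from the boundary step onward
            value *= scale
    return value


if __name__ == "__main__":
    # Decay by 10x at step 100 and again at step 200.
    schedule = {100: 0.1, 200: 0.1}
    for step in (0, 99, 100, 250):
        print(step, piecewise_constant(step, 1.0, schedule))
```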

Best

GJ

vroulet commented 3 months ago

Hello @gjhuizing,

Thanks for catching this! If you are willing to make such a PR, that would be great!

gjhuizing commented 3 months ago

Great!