google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Math pseudocode in description of SGD with Nesterov is incorrect #900

Closed: satyenkale closed this issue 3 months ago

satyenkale commented 3 months ago

The math pseudocode in the description of SGD with Nesterov is currently given as:

[Screenshot 2024-04-03 at 4:12:26 PM: the pseudocode as currently documented]

I believe this is incorrect. Apart from the circular definition of m_t when nesterov = False, the definition of m_t itself needs to be corrected. The correct equations should be:

[Screenshot 2024-04-03 at 5:07:24 PM: proposed corrected equations]

Or alternatively,

[Screenshot 2024-04-03 at 5:05:34 PM: alternative form of the corrected equations]

This can be verified from equations (3) and (4) in Sutskever et al., "On the importance of initialization and momentum in deep learning" (2013), with the change of variables m_t = -v_t/epsilon and alpha_t = epsilon.
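A quick numerical check of that change of variables, written as a minimal sketch in plain Python (scalar floats stand in for JAX arrays). `momentum_step` uses a non-circular momentum form, m_t = mu * m_{t-1} + g_t with update -alpha * m_t, or -alpha * (g_t + mu * m_t) when Nesterov is on. To our understanding this mirrors how `optax.trace` accumulates momentum, but it is not necessarily identical to the equations in the screenshots. `sutskever_nag_step` assumes equations (3) and (4) of the paper take the standard Nesterov form v_{t+1} = mu * v_t - epsilon * grad f(theta_t + mu * v_t), theta_{t+1} = theta_t + v_{t+1}. Both helpers and the quadratic test function are hypothetical illustrations, not Optax code. Substituting m_t = -v_t/epsilon and tracking the lookahead point phi_t = theta_t + mu * v_t gives m_{t+1} = mu * m_t + g_t and phi_{t+1} = phi_t - epsilon * (g_t + mu * m_{t+1}) with g_t = grad f(phi_t). The momentum-form iterate should therefore match Sutskever's lookahead point at every step:

```python
def momentum_step(param, grad, m_prev, lr, mu, nesterov=False):
    """One SGD-with-momentum step in a non-circular form (hypothetical helper).

    m_t = mu * m_{t-1} + g_t
    u_t = -lr * m_t                if nesterov is False
    u_t = -lr * (g_t + mu * m_t)   if nesterov is True
    """
    m = mu * m_prev + grad
    direction = grad + mu * m if nesterov else m
    return param - lr * direction, m


def sutskever_nag_step(theta, v, grad_fn, eps, mu):
    """Nesterov momentum with the gradient taken at the lookahead point theta + mu * v."""
    v_next = mu * v - eps * grad_fn(theta + mu * v)
    return theta + v_next, v_next


def max_gap(steps=50, x0=1.0, lr=0.1, mu=0.9, curvature=3.0):
    """Run both recursions on f(x) = 0.5 * curvature * x**2 and return the largest mismatch.

    Checks phi_t == theta_t + mu * v_t and m_t == -v_t / lr at every step.
    """
    def grad_fn(x):
        return curvature * x

    theta, v = x0, 0.0  # Sutskever et al. variables
    phi, m = x0, 0.0    # momentum-form variables; phi_0 == theta_0 because v_0 == 0
    gap = 0.0
    for _ in range(steps):
        theta, v = sutskever_nag_step(theta, v, grad_fn, lr, mu)
        phi, m = momentum_step(phi, grad_fn(phi), m, lr, mu, nesterov=True)
        gap = max(gap, abs(phi - (theta + mu * v)), abs(m - (-v / lr)))
    return gap


print(max_gap())  # tiny: the two recursions agree up to floating-point rounding
```

The momentum form keeps only the buffer m_t as extra state and evaluates gradients at the current iterate, which is why it is convenient for a gradient-transformation library; the price is that its iterate corresponds to Sutskever's lookahead point rather than to theta_t itself.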

vroulet commented 3 months ago

Oh right, thank you very much for catching that, @satyenkale! Could you make the correction in a quick PR?

satyenkale commented 3 months ago

Thanks! I created a PR (https://github.com/google-deepmind/optax/pull/901), but I am not sure whether all the checks passed.

vroulet commented 3 months ago

Thank you again! It should go through. The failing checks on https://github.com/google-deepmind/optax/pull/901 seem to be caused by recent changes in JAX that broke some of our code. We'll investigate that.

satyenkale commented 3 months ago

Great, thanks!

(Sent by email in reply to Vincent Roulet's comment of Wed, Apr 3, 2024, 6:14 PM.)

vroulet commented 3 months ago

Solved in #901. Thanks again!