google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Add AdEMAMix Optimizer #1057

Closed: mathDR closed this 1 week ago

mathDR commented 3 weeks ago

This PR adds the AdEMAMix optimizer from the arXiv preprint "The AdEMAMix Optimizer: Better, Faster, Older".

Closes #1058
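
For context, the core idea of AdEMAMix is to combine two exponential moving averages of the gradient, a fast one (as in Adam) and a much slower one, with an Adam-style second moment. The snippet below is a minimal sketch of that update rule under my reading of the paper; the names, defaults, and state layout are illustrative only and are not the PR's actual implementation, which follows optax conventions.

```python
# Minimal sketch of the AdEMAMix update: fast EMA m1 (b1), slow EMA m2 (b3),
# second moment nu (b2); direction = (m1_hat + alpha * m2) / (sqrt(nu_hat) + eps).
# Names and defaults are illustrative, not the PR's actual code.
import jax
import jax.numpy as jnp
import optax


def ademamix_sketch(learning_rate=1e-3, b1=0.9, b2=0.999, b3=0.9999,
                    alpha=5.0, eps=1e-8):
  def init_fn(params):
    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    return dict(count=jnp.zeros([], jnp.int32), m1=zeros, m2=zeros, nu=zeros)

  def update_fn(grads, state, params=None):
    del params
    count = state['count'] + 1
    tmap = jax.tree_util.tree_map
    m1 = tmap(lambda m, g: b1 * m + (1 - b1) * g, state['m1'], grads)
    m2 = tmap(lambda m, g: b3 * m + (1 - b3) * g, state['m2'], grads)
    nu = tmap(lambda v, g: b2 * v + (1 - b2) * g**2, state['nu'], grads)
    # Bias-correct the fast EMA and second moment (as in Adam);
    # the slow EMA is left uncorrected in the paper.
    m1_hat = tmap(lambda m: m / (1 - b1**count), m1)
    nu_hat = tmap(lambda v: v / (1 - b2**count), nu)
    updates = tmap(
        lambda mh, ms, vh: -learning_rate * (mh + alpha * ms) / (jnp.sqrt(vh) + eps),
        m1_hat, m2, nu_hat)
    return updates, dict(count=count, m1=m1, m2=m2, nu=nu)

  return optax.GradientTransformation(init_fn, update_fn)
```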

The docs have been updated, along with the relevant files. Furthermore, I ran a similar test to replicate the Rosenbrock figure from the paper. [Rosenbrock comparison figure attached in the original PR; a sketch of such a check follows at the end of this comment.]

Currently the docstrings are implemented, but further descriptions could be added. I will reach out to the paper authors to assist with that (if they are willing).
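
As a rough illustration of what such a Rosenbrock check looks like with optax, a minimal loop is sketched below. It assumes the `ademamix_sketch` transformation from the snippet above rather than the PR's code, and the learning rate and step count are arbitrary.

```python
# Sketch of a Rosenbrock check, assuming the `ademamix_sketch` transformation
# defined above; the PR's figure was produced differently, this only
# illustrates the shape of such a test.
import jax
import jax.numpy as jnp
import optax


def rosenbrock(p):
  return (1.0 - p[0]) ** 2 + 100.0 * (p[1] - p[0] ** 2) ** 2

params = jnp.array([-1.5, 2.0])
opt = ademamix_sketch(learning_rate=2e-2)  # hypothetical constructor from the sketch above
state = opt.init(params)

@jax.jit
def step(params, state):
  grads = jax.grad(rosenbrock)(params)
  updates, state = opt.update(grads, state, params)
  return optax.apply_updates(params, updates), state

for _ in range(5_000):
  params, state = step(params, state)
print(params)  # should approach the minimum at (1., 1.)
```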

vroulet commented 2 weeks ago

Thanks @mathDR for the contribution! Optimizers that have not yet passed the test of time (too recent, like this one) go in the contrib folder (see https://optax.readthedocs.io/en/latest/development.html#inclusion-criteria). Before continuing (moving this optimizer into the contrib folder), can you wait for #1060 to be merged? I would like all optimizers to abide by common tests to avoid issues like #1038.

PS: please use an editor that enforces indents of two spaces, not four.
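
For what it's worth, once an optimizer lands in contrib it is used like any other optax transformation; the sketch below assumes a hypothetical `optax.contrib.ademamix` entry point, which this PR does not yet define.

```python
# Hypothetical usage once merged into optax.contrib; the `ademamix` name and
# its arguments are assumptions, not part of this PR.
import jax.numpy as jnp
import optax
from optax import contrib

params = {'w': jnp.zeros(3)}
opt = contrib.ademamix(learning_rate=1e-3)  # hypothetical entry point
state = opt.init(params)
```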

mathDR commented 2 weeks ago

Okay, thanks for the comments @vroulet. I will set a reminder to check the status of #1060 and pick this up once it is merged.

vroulet commented 1 week ago

Hello @mathDR

#1060 has been merged.

If you want, you can continue this PR: put the algorithm in the contrib folder and make sure it passes the common tests.

mathDR commented 1 week ago

I closed this and deleted the forked repo (it will be easier to fork again and start anew).