google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Adds AdEMAMix Optimizer to `contrib` #1104

Open mathDR opened 1 month ago

mathDR commented 1 month ago

This PR adds the AdEMAMix optimizer from the arXiv preprint "The AdEMAMix Optimizer: Better, Faster, Older".

Closes #1058

The docs have been updated, along with the relevant files.

The docstrings are currently in place, but further descriptions could be added. I will reach out to the paper authors to see whether they are willing to help with that.
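For reference, the update rule in the preprint is Adam with a second, much slower EMA of the gradient mixed into the numerator. A minimal sketch of one step in plain JAX follows. It assumes that formulation: a bias-corrected fast EMA `m1`, a slow EMA `m2` with decay `b3` and no bias correction, a mixing coefficient `alpha`, and AdamW-style decoupled weight decay. The function name `ademamix_step` and its default values are illustrative only. This is not the `optax.contrib` API added by the PR.

```python
import jax
import jax.numpy as jnp


def ademamix_step(params, grads, state, step, lr=1e-3, b1=0.9, b2=0.999,
                  b3=0.9999, alpha=5.0, eps=1e-8, weight_decay=0.0):
    """One AdEMAMix update on a pytree of parameters (illustrative sketch).

    `state` is a tuple (m1, m2, nu) of pytrees shaped like `params`, and
    `step` is the 1-based iteration count used for bias correction.
    """
    m1, m2, nu = state
    tree_map = jax.tree_util.tree_map
    # Fast EMA of the gradient, as in Adam.
    m1 = tree_map(lambda m, g: b1 * m + (1 - b1) * g, m1, grads)
    # Slow EMA of the gradient; this one is NOT bias-corrected.
    m2 = tree_map(lambda m, g: b3 * m + (1 - b3) * g, m2, grads)
    # EMA of the squared gradient, as in Adam.
    nu = tree_map(lambda v, g: b2 * v + (1 - b2) * g**2, nu, grads)
    bc1 = 1 - b1**step  # bias correction for m1
    bc2 = 1 - b2**step  # bias correction for nu

    def update(p, m_fast, m_slow, v):
        m_hat = m_fast / bc1
        v_hat = v / bc2
        # Mix the fast and slow EMAs, then normalize as in Adam.
        direction = (m_hat + alpha * m_slow) / (jnp.sqrt(v_hat) + eps)
        # Decoupled weight decay is added to the step, AdamW-style.
        return p - lr * (direction + weight_decay * p)

    new_params = tree_map(update, params, m1, m2, nu)
    return new_params, (m1, m2, nu)
```

Each accumulator starts as `jax.tree_util.tree_map(jnp.zeros_like, params)`. In optax itself this logic would be wrapped in a `GradientTransformation` with separate `init` and `update` functions.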

mathDR commented 1 month ago

Hey @zcharles8, on rereading the paper I discovered that the authors have a full JAX implementation here.

I will email them and see if they have any qualms about me using aspects of their code as part of this PR.

Do you know if there would be licensing problems with this approach?

Please advise.

vroulet commented 1 month ago

> Do you know if there would be licensing problems with this approach?

Great question. From reading Apple's license, there probably would be... Let me email one of the authors I know.

mathDR commented 1 month ago

Thanks @vroulet. I also tried emailing all three authors (but I had to google their contact info, so those addresses might be outdated).

Basically, I asked whether they are keen to get a version into contrib. If so, I would continue my PR but use their docstrings (with attribution).

But if they don't want a version in Optax and would rather users use their version, I would close the PR.

mathDR commented 1 month ago

Emailed with Matteo Pagliardini and he wrote:

Very happy to hear that you found the work useful and thanks for the PR. Having AdEMAMix as part of Optax would be great. Feel free to proceed with the PR and use our docstrings. 

So I went ahead and completed the PR.

One open question: AdEMAMix uses a pretty bespoke scheduler for b3. The alpha scheduler can be implemented via the vanilla `linear_schedule`, but I couldn't find a drop-in replacement for the b3 scheduler.

Both are now used in the Rosenbrock example. Should we add a new scheduler type for completeness?

zcharles8 commented 1 month ago

> One open question: AdEMAMix uses a pretty bespoke scheduler for b3. The alpha scheduler can be implemented via the vanilla `linear_schedule`, but I couldn't find a drop-in replacement for the b3 scheduler.
>
> Both are now used in the Rosenbrock example, but I didn't know if we wanted to add a new scheduler type for completeness?

I think that adding the schedulers (`alpha_scheduler` and `b3_scheduler`) to the library is totally reasonable, if only to make it easy for someone to use the schedulers from the paper directly.

FWIW, I think this PR is really good (e.g. things like the typing information) and would prefer to have it pushed through (nice work, @mathDR!)

mathDR commented 1 month ago

Okay, thanks for the tip about rendering the docs. That let me fix a lot of weirdness (and a LaTeX error!).

I think this is good to go!

mathDR commented 3 weeks ago

So @vroulet, do you have to merge it, or does another authorized user have to do it?

vroulet commented 3 weeks ago

The PR needs to be approved by one of the internal owners of the package (as I did). Copybara then automatically syncs the PR with the internal code and produces a snapshot for another maintainer to check. Once that other maintainer gives their approval, the PR is merged. (That's why PRs often take a while to be merged even after they are approved.)

fabianp commented 3 weeks ago

Thanks for the contribution! Note that you'll need to edit `gallery.rst` for your example to be displayed at https://optax.readthedocs.io/en/latest/gallery.html

mathDR commented 3 weeks ago

Okay, great. I updated `gallery.rst` and added a PNG thumbnail to the `images/` directory. I also added a Colab link, as the other examples do.

Note: when I render the docs locally, my example shows a `KeyboardInterrupt` that doesn't appear in the original Jupyter notebook.

Is this an artifact of Sphinx attempting to render the notebook? Hopefully it's a problem with my local environment and won't persist on main.

fabianp commented 3 weeks ago

That's because your test is taking too long to run. You can either make it quicker, or add it to `nb_execution_excludepatterns` in https://github.com/google-deepmind/optax/blob/main/docs/conf.py (but then it won't be run as part of the test suite).
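For the second option, the change amounts to adding one entry to the `nb_execution_excludepatterns` list in `docs/conf.py`. This is a myst-nb setting that holds glob patterns for notebooks that should not be executed during the docs build. The notebook path below is hypothetical; it should match wherever the example actually lives.

```python
# docs/conf.py (excerpt): myst-nb skips executing any notebook whose path
# matches one of these glob patterns during the docs build.
nb_execution_excludepatterns = [
    # ... existing entries stay as they are ...
    # Hypothetical path for the new AdEMAMix Rosenbrock example.
    'examples/contrib/ademamix_rosenbrock.ipynb',
]
```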