google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Add Muon Optimizer to `contrib` #1126

Open leloykun opened 3 weeks ago

leloykun commented 3 weeks ago

Adds support for @KellerJordan's Muon optimizer, detailed here: https://github.com/KellerJordan/modded-nanogpt


The Muon optimizer performs steepest descent under (approximately) the spectral norm, or more precisely, under the Schatten-p norm for some large p (the Schatten-∞ norm is the spectral norm).
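The following is a minimal JAX sketch of what this means in practice. It assumes the Newton-Schulz approach used in the modded-nanogpt reference. For a gradient with SVD G = U S V^T, the steepest-descent direction under the spectral norm is U V^T. The function `polar_factor_svd` computes this direction exactly, and `newton_schulz5` approximates it using only matrix multiplications.

Both function names are hypothetical, not the PR's code. The quintic coefficients (3.4445, -4.7750, 2.0315) are assumed to match the reference. Each iteration applies the odd polynomial s -> a*s + b*s^3 + c*s^5 to every singular value and leaves the singular vectors unchanged. As a result, the singular values are pushed into a band around 1 rather than exactly onto 1, which is one way to read the "approximately" above.

```python
import jax
import jax.numpy as jnp


def polar_factor_svd(g: jax.Array) -> jax.Array:
    """Exact steepest-descent direction under the spectral norm: U @ V^T."""
    u, _, vt = jnp.linalg.svd(g, full_matrices=False)
    return u @ vt


def newton_schulz5(g: jax.Array, steps: int = 5, eps: float = 1e-7) -> jax.Array:
    """Approximates U @ V^T with a quintic Newton-Schulz iteration (no SVD)."""
    if g.ndim != 2:
        raise ValueError(f"expected a matrix, got shape {g.shape}")
    a, b, c = 3.4445, -4.7750, 2.0315  # assumed to match the modded-nanogpt reference
    # Dividing by the Frobenius norm guarantees every singular value is <= 1.
    x = g / (jnp.linalg.norm(g) + eps)
    transposed = x.shape[0] > x.shape[1]
    if transposed:  # work in the wide orientation so x @ x.T is the smaller Gram matrix
        x = x.T
    for _ in range(steps):
        gram = x @ x.T
        # x <- a*x + b*(x x^T) x + c*(x x^T)^2 x, i.e. s -> a*s + b*s^3 + c*s^5 per singular value
        x = a * x + (b * gram + c * gram @ gram) @ x
    return x.T if transposed else x


g = jax.random.normal(jax.random.PRNGKey(0), (64, 32))
print(jnp.linalg.svd(newton_schulz5(g), compute_uv=False))
print(jnp.linalg.svd(polar_factor_svd(g), compute_uv=False))
```

With a 64x32 Gaussian gradient, the Newton-Schulz output's singular values should land roughly between 0.7 and 1.2 after five steps. The SVD version gives exactly 1 up to float error. The trade-off is cost: an SVD is expensive on accelerators, while the iteration uses only matmuls.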

leloykun commented 2 weeks ago

Hi all!

How do I restrict the tests to exclude vectors as weights?

vroulet commented 2 weeks ago

Hello @leloykun,

Looks like an interesting contribution, thank you!

> How do I restrict the tests to exclude vectors as weights?

I don't fully understand your question. Could you give more context?

Other questions/remarks:

How does this optimizer treat vector-shaped values? The "muon_iterator" could run on vectors but would not return what you want, so how do you make the distinction? Should it take a mask so that it only applies to matrices? Should it raise a ValueError?
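One way to implement the mask option, sketched with plain `jax.tree_util`: build the mask from each leaf's rank, so that only 2-D parameters are routed to Muon. The helper names `muon_mask` and `muon_labels` are hypothetical.

```python
import jax
import jax.numpy as jnp


def muon_mask(params):
    """Boolean pytree: True for matrix-shaped leaves that Muon should update."""
    return jax.tree_util.tree_map(lambda p: jnp.ndim(p) == 2, params)


def muon_labels(params):
    """String labels: "muon" for matrices, "adam" for everything else."""
    return jax.tree_util.tree_map(
        lambda p: "muon" if jnp.ndim(p) == 2 else "adam", params
    )


params = {"w": jnp.zeros((4, 3)), "b": jnp.zeros(3), "scale": jnp.zeros(())}
print(muon_mask(params))
print(muon_labels(params))
```

Either helper can be passed as a callable to existing optax wrappers. For example, use `optax.masked(muon_tx, muon_mask)` or `optax.multi_transform({"muon": muon_tx, "adam": optax.adam(1e-3)}, muon_labels)`, where `muon_tx` stands in for the transformation added in this PR.

Note that with `masked` alone, unmasked leaves have their updates passed through untransformed, so vectors (biases, norm scales) still need some fallback. `multi_transform` makes that fallback explicit.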

Also, could you add the mathematical description to the docstring? That would help greatly.
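A possible starting point for that description, written as a LaTeX sketch. It states the standard duality argument for steepest descent under the spectral norm, followed by a simplified Muon step with plain heavy-ball momentum. The notation (mu for momentum, eta for learning rate, NS_k for k Newton-Schulz steps) is assumed here. The Nesterov variant and any shape-dependent scaling used in the reference are omitted.

```latex
% Steepest descent under the spectral norm \|\cdot\|_2, with reduced SVD G = U \Sigma V^\top.
% Since |\langle G, \Delta \rangle| \le \|G\|_* \, \|\Delta\|_2 (nuclear/spectral duality):
\min_{\|\Delta\|_2 \le 1} \langle G, \Delta \rangle = -\|G\|_*,
\qquad \text{attained at } \Delta^\star = -U V^\top .

% Simplified Muon step (heavy-ball momentum, no Nesterov, no extra scaling):
B_t = \mu B_{t-1} + G_t, \qquad
O_t = \mathrm{NS}_k(B_t) \approx U_t V_t^\top \ \text{where } B_t = U_t \Sigma_t V_t^\top, \qquad
W_{t+1} = W_t - \eta \, O_t .

% Each Newton-Schulz step acts on every singular value \sigma of X / \|X\|_F as
\sigma \mapsto a\sigma + b\sigma^3 + c\sigma^5 .
```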

Finally, put references at the end of the docstring (we'll adopt this format with #1129)

Thank you!

leloykun commented 2 weeks ago

Hi @vroulet,

Muon is only defined for matrix-shaped values.

I'm thinking of raising an error when the input is vector-shaped, but where is the best place to put it? If other optimizers here already do this, could you point me to them?
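A minimal sketch of one placement, assuming the transformation follows optax's usual init/update split: validate the parameter shapes in the init function and raise a `ValueError` that names the offending leaves. The helpers `check_matrix_params` and `init_momentum` are hypothetical and are not existing optax APIs.

```python
import jax
import jax.numpy as jnp


def check_matrix_params(params) -> None:
    """Raise ValueError if any leaf is not 2-D (hypothetical helper, not an optax API)."""
    bad = [
        (jax.tree_util.keystr(path), jnp.shape(leaf))
        for path, leaf in jax.tree_util.tree_leaves_with_path(params)
        if jnp.ndim(leaf) != 2
    ]
    if bad:
        details = ", ".join(f"{name}: shape {shape}" for name, shape in bad)
        raise ValueError(
            "Muon is only defined for matrix-shaped parameters; "
            f"got {details}. Mask these out (e.g. with optax.masked)."
        )


def init_momentum(params):
    """Sketch of a Muon-style init function: validate shapes once, then zero momentum."""
    check_matrix_params(params)
    return jax.tree_util.tree_map(jnp.zeros_like, params)


state = init_momentum({"w": jnp.ones((3, 2))})  # fine
# init_momentum({"w": jnp.ones((3, 2)), "b": jnp.ones(2)})  # raises ValueError naming ['b']
```

Because array ranks are static, this check fires once at `optimizer.init(params)` instead of on every step. The same check placed inside `update` would also run at trace time under `jit`, so it would not add runtime cost either.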

vroulet commented 1 week ago

Hello @leloykun

Thanks again and sorry for the delay!