google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0
1.65k stars, 181 forks

schedule_free: fix broadcasting of scalar arrays to 1d arrays #1042

Closed n-gao closed 1 month ago

n-gao commented 1 month ago

Currently, the momentum is stored in a 1D array [b1] of shape (1,). We should instead store it in a scalar array of shape () to avoid broadcasting scalars to (1,) in schedule_free_eval_params.

Before:

import jax.numpy as jnp
import optax

opt = optax.contrib.schedule_free_adamw()
x = jnp.ones(())  # scalar parameter of shape ()
state = opt.init(x)
optax.contrib.schedule_free_eval_params(state, x).shape
# (1,)

After:

opt = optax.contrib.schedule_free_adamw()
x = jnp.ones(())
state = opt.init(x)
optax.contrib.schedule_free_eval_params(state, x).shape
# ()
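
The shape change above follows from ordinary broadcasting rules. The sketch below is not the library's code: `eval_params_sketch` is a hypothetical stand-in for an interpolation between two parameter sequences, used only to show how a coefficient of shape (1,) promotes a scalar result to (1,), while a coefficient of shape () leaves it unchanged.

```python
import jax.numpy as jnp


def eval_params_sketch(b1, y, z):
    # Hypothetical stand-in for the evaluation step; assumed form, not optax's code.
    return (y - (1.0 - b1) * z) / b1


y = jnp.ones(())  # scalar parameter, shape ()
z = jnp.ones(())  # second scalar sequence, shape ()

b1_1d = jnp.asarray([0.9])     # coefficient stored as a 1D array, shape (1,)
b1_scalar = jnp.asarray(0.9)   # coefficient stored as a scalar array, shape ()

shape_with_1d = eval_params_sketch(b1_1d, y, z).shape
shape_with_scalar = eval_params_sketch(b1_scalar, y, z).shape

print(shape_with_1d)      # (1,)
print(shape_with_scalar)  # ()
```

Non-scalar parameters are unaffected either way, since a shape such as (3,) broadcasts against both (1,) and () to (3,); only scalar leaves change shape.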

fabianp commented 1 month ago

Your solution seems reasonable to me. However, there are now some doctest errors in optax/contrib/_schedule_free.py, probably because those docstrings are now included in the docs.

n-gao commented 1 month ago

@fabianp I am quite unfamiliar with the docs. I thought this would be a simple change. Is there some documentation on this? Otherwise, I can also remove the doc changes.

fabianp commented 1 month ago

You might not need to build the docs (although if you want to, it's described in the README). Just check the errors from the failing CI (https://github.com/google-deepmind/optax/actions/runs/10664454490/job/29555745356?pr=1042). As you can see, the issue seems to be that some examples in the docstrings use schedule_free_eval_params instead of the full name optax.contrib.schedule...

Let me know if this doesn't make sense.

n-gao commented 1 month ago

I haven't touched any line related to schedule_free_eval_params. And all the other lines in docs/api/contrib.rst also don't use complete paths? Let me check whether the tests pass if I remove the lines again.

n-gao commented 1 month ago

I don't get why the tests are failing. It works locally and the change seems unrelated. @fabianp do you have another idea?

fabianp commented 1 month ago

You can undo the changes in docs/api/contrib.rst if you want, since they are orthogonal to this PR.

n-gao commented 1 month ago

Done.

fabianp commented 1 month ago

Can you also add a test showing that the new approach doesn't have the broadcasting problem?

n-gao commented 1 month ago

I added a test that fails before and succeeds after the PR.
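
The test added in the PR is not shown in this thread. As a rough sketch of what a shape-preservation check of this kind could look like, the example below compares leaf shapes before and after a transformation. The names `shapes_preserved` and `make_eval` are hypothetical, and `make_eval` is a simplified stand-in for the evaluation step rather than the optax implementation.

```python
import jax
import jax.numpy as jnp


def shapes_preserved(fn, params):
    """Returns True if fn(params) has the same leaf shapes as params."""
    out = fn(params)
    in_shapes = [jnp.shape(leaf) for leaf in jax.tree_util.tree_leaves(params)]
    out_shapes = [jnp.shape(leaf) for leaf in jax.tree_util.tree_leaves(out)]
    return in_shapes == out_shapes


def make_eval(b1):
    # Hypothetical stand-in for the evaluation step, parameterized by the
    # stored coefficient b1; for simplicity both sequences are set to y.
    def eval_fn(params):
        return jax.tree_util.tree_map(
            lambda y: (y - (1.0 - b1) * y) / b1, params
        )
    return eval_fn


# A parameter tree containing both a vector and a scalar leaf.
params = {"w": jnp.ones((3,)), "b": jnp.ones(())}

ok_scalar = shapes_preserved(make_eval(jnp.asarray(0.9)), params)
ok_1d = shapes_preserved(make_eval(jnp.asarray([0.9])), params)

print(ok_scalar)  # True: the scalar leaf keeps shape ()
print(ok_1d)      # False: the scalar leaf is broadcast to (1,)
```

Including a scalar leaf in the test parameters is the important part, since vector leaves keep their shape with either storage choice.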

n-gao commented 1 month ago

Added a comment and changed the variable name. Though, this criticism probably applies to all the other tests in that file.

fabianp commented 1 month ago

excellent, thanks!