Closed n-gao closed 2 months ago
your solution seems reasonable to me. However, there are now some doctest errors in `optax/contrib/_schedule_free.py`, probably due to them now being included in the docs.
@fabianp I am quite unfamiliar with the docs. I thought this would be a simple change. Is there some documentation on this? Otherwise, I can also remove the doc changes.
you might not need to build the docs (although if you wanted to, it's described in the README). Just check the errors from the failing CI (https://github.com/google-deepmind/optax/actions/runs/10664454490/job/29555745356?pr=1042). As you can see, the issue seems to be that some examples in the docstrings use `schedule_free_eval_params` instead of the full name `optax.contrib.schedule...`
let me know if this doesn't make sense
I haven't touched any line related to `schedule_free_eval_params`. All the other lines in `docs/api/contrib.rst` also don't use complete paths? Let me check if the tests pass if I remove the lines again.
I don't get why the tests are failing. It works locally and the change seems unrelated. @fabianp do you have another idea?
you can undo the changes in docs/api/contrib.rst if you want since they are orthogonal to this PR
done
can you also add a test showing that the new approach doesn't have the broadcasting problem?
I added a test that fails before and succeeds after the PR
Added a comment and changed the variable name. Though, this criticism probably applies to all the other tests in that file.
excellent, thanks!
Currently, the momentum is stored in a 1D array `[b1]` of shape `(1,)`. We should store it instead in a scalar array `()` to avoid broadcasting scalars to `(1,)` in `schedule_free_eval_params`.
Before:
After:
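As an illustration of the shape issue described above (this is not the PR's own before/after snippet), here is a minimal sketch using NumPy, whose broadcasting rules `jax.numpy` follows. The `combine` function is a hypothetical stand-in for an elementwise update that mixes parameters with the momentum `b1`; it is an assumption for demonstration and not optax's actual implementation.

```python
import numpy as np


def combine(b1, y, z):
    # Hypothetical stand-in for an elementwise parameter update that uses
    # the momentum b1. This is an assumption, not optax's actual code.
    return (y - (1.0 - b1) * z) / b1


# Scalar parameters, shape ().
y = np.asarray(1.0)
z = np.asarray(0.5)

b1_1d = np.asarray([0.9])    # momentum stored as [b1], shape (1,)
b1_scalar = np.asarray(0.9)  # momentum stored as a scalar array, shape ()

before = combine(b1_1d, y, z)
after = combine(b1_scalar, y, z)

print(np.shape(before))  # (1,): the scalar parameter was broadcast
print(np.shape(after))   # (): the scalar parameter keeps its shape
```

With the 1D momentum, every scalar leaf comes back with shape `(1,)`, so the returned parameters no longer match the shapes of the parameters that went in. Storing the momentum with shape `()` leaves scalar leaves unchanged and still broadcasts correctly against higher-dimensional leaves.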