google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

schedule-free optimizer: ensure it's possible to donate both the state and the params #1059

Closed enolan closed 1 month ago

enolan commented 2 months ago

Prior to this commit, with a schedule-free optimizer, if you tried to donate an entire train state, including both the parameters and a freshly initialized optimizer state, you'd get an "INVALID_ARGUMENT: Attempt to donate the same buffer twice in Execute()" error, because z was the same array as params. This commit fixes the issue and adds a test.
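For context, donated buffers must be unique: the freshly initialized schedule-free state stored `z` as the very same array object as `params`, so donating the whole train state handed one buffer to XLA twice. Here is a minimal pure-Python sketch of that aliasing pattern and the copy fix; the function names are illustrative, and in the real JAX code the copy would be something like `jnp.copy(params)`:

```python
# Illustrative sketch of the aliasing bug behind the
# "Attempt to donate the same buffer twice" error.
# Plain lists stand in for device buffers here.

def init_buggy(params):
    # Buggy init: z is the *same* object as params, so donating the
    # train state (params + optimizer state) donates one buffer twice.
    return {"z": params}

def init_fixed(params):
    # Fix: store an explicit copy (jnp.copy(params) in JAX), so the
    # two leaves refer to distinct buffers and can both be donated.
    return {"z": list(params)}

params = [1.0, 2.0, 3.0]

buggy_state = init_buggy(params)
fixed_state = init_fixed(params)

print(buggy_state["z"] is params)  # True: aliased, donation would fail
print(fixed_state["z"] is params)  # False: distinct buffer
print(fixed_state["z"] == params)  # True: same values, as required
```

The copy costs one extra allocation at init time but makes every leaf of the train state safe to donate independently.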

vroulet commented 2 months ago

Hello @enolan,

Thank you for catching that! Could you wait for #1060 to be merged? I can ping you back once it's merged into head; you can then merge your PR with head, and I'll review.

enolan commented 2 months ago

Sure, no problem.

vroulet commented 2 months ago

Hello @enolan, your PR should be good to go. Just sync with head, check whether the problem persists without your change (i.e., whether #1060 changed the behavior), then upload, and I'll approve. Thanks again for looking into this!

enolan commented 1 month ago

We should be good. I rebased and re-checked: the test still fails without the copy and passes with it.