google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Replaced seed with key in add_noise and noisy_sgd #1138

Open Tomas542 opened 1 week ago

Tomas542 commented 1 week ago

Replaced seed with key in add_noise, in favor of the jax.random-like style. Added a (duplicated) example for add_noise, based on the one from noisy_sgd. Changed seed to key in noisy_sgd. Imported chex in _alias.py for annotation purposes.
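
For context, a rough before/after sketch of the renaming (illustrative signatures; the final argument order is settled later in this thread):

import jax
import optax

# Before: the noise stream was seeded with a plain integer.
# tx = optax.noisy_sgd(learning_rate=0.01, eta=0.01, gamma=0.55, seed=0)

# After: the caller passes an explicit PRNG key, jax.random-style.
tx = optax.noisy_sgd(
    learning_rate=0.01, eta=0.01, gamma=0.55, key=jax.random.key(0)
)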

google-cla[bot] commented 1 week ago

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up-to-date status, view the checks section at the bottom of the pull request.

rdyro commented 1 week ago

These changes look good! Can you sign the Google CLA so we can run the unit tests?

I like the change of placing the key first in the argument list. Can you take a look at these files as well:

It'd be good to rename the argument from seed to key in all these places and move the key argument to the beginning of the argument list; this might require touching up a couple of tests.

Tomas542 commented 1 week ago

I signed the CLA yesterday and then went to bed. It's a great idea, but I think I'll do it this weekend. I also have a question (sorry if it's a stupid one, I've never done pull requests before): if I make a new commit with these changes to my branch, do they automatically go into this PR? Or is there some button I have to press?

rdyro commented 1 week ago

Awesome

if I make a new commit with these changes to my branch, do they automatically go into this PR? Or is there some button I have to press?

Not a stupid question at all, and thanks for the contribution! Exactly: any additional commits on your branch will automatically be added to this PR. At the end, you can also squash your commits into one before merging, with something like:

$ git rebase -i main
$ # change all but one commit to "squash", leaving one of them (probably the first) as "pick"
$ # if everything went well:
$ git push -f  # overwrites history, which is necessary after squashing commits
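
As an aside, GitHub's "Squash and merge" option on the PR page (if enabled for the repo) achieves the same result at merge time, without rewriting history locally.
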
Tomas542 commented 1 week ago

OK, one more question: a lot of things will change. Maybe we should add key as an optional argument and emit a DeprecationWarning if seed is set, along the lines of: argument seed will be removed (or replaced, in cases where seed is already a PRNGKey) in the next release; use key instead. But in that case key can't be the first argument, and we'd have to type it as Optional for now; then in a later major release I'd remove/replace seed with key entirely. Or is it better to change everything now? What do you think?
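
A hypothetical transition shim might look like the sketch below (illustrative only; the argument names mirror this thread, the body is a placeholder, and the discussion below settles on a different approach):

import warnings
from typing import Optional

import chex
import jax

def noisy_sgd(
    learning_rate: float,
    eta: float = 0.01,
    gamma: float = 0.55,
    key: Optional[chex.PRNGKey] = None,
    seed: Optional[int] = None,  # deprecated alias, kept for one release
):
  # Accept the old integer seed for one more release, with a warning.
  if seed is not None:
    warnings.warn(
        "Argument seed will be removed in the next release; use key instead.",
        DeprecationWarning,
    )
    key = jax.random.PRNGKey(seed)
  ...  # build the gradient transformation from key, as before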

Tomas542 commented 6 days ago

Hi, I haven't had time to make the changes yet. I agree with the positions of the arguments, but I don't agree that key should get a default value, for two reasons:

rdyro commented 5 days ago

A default key value might be confusing, giving determinism where people don't expect it. But noisy_sgd is a first-order optimizer, and argument-order consistency is definitely valuable. Maybe we can keep key last, but required.

Something like:

def noisy_sgd(
    learning_rate: base.ScalarOrSchedule,
    eta: float = 0.01,
    gamma: float = 0.55,
    key: chex.PRNGKey | None = None,
) -> base.GradientTransformation:
  if key is None:
    raise ValueError(
        "noisy_sgd optimizer requires specifying random key: "
        "noisy_sgd(..., key=random.key(0))"
    )
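
Call sites under this proposal would then look like the following sketch:

import jax

tx = noisy_sgd(learning_rate=1e-2, key=jax.random.key(0))
# Omitting key, e.g. noisy_sgd(learning_rate=1e-2), raises the ValueError above.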

@Tomas542 @vroulet what do you think?

vroulet commented 5 days ago

I like this idea. It may also smooth out backward compatibility issues with clear raised errors.

rdyro commented 3 days ago

@Tomas542 could you add this change?

Tomas542 commented 1 day ago

Yep, I can do this. I'll also try to update the tests.

To summarize: we decided to replace all seed arguments with key, give key a default value of None, and raise an error if it is None. As for argument order, we put key in the last position for optimizers and in the first position in some other functions, such as add_noise?

UPD: I'd also open a follow-up PR after we finish this discussion. And the annotation will be Optional, since the a | b syntax isn't supported in Python 3.9.
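
For illustration, call sites under this summary would look roughly like the sketch below (hypothetical; the exact signatures land in the PR itself):

import jax
import optax

key = jax.random.PRNGKey(0)

# Optimizers: key goes last, defaults to None, and raises if left unset.
tx = optax.noisy_sgd(learning_rate=1e-2, key=key)

# Other transformations such as add_noise: key goes first.
tx = optax.add_noise(key, eta=0.01, gamma=0.55)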

vroulet commented 1 day ago

Yes, good point about the annotation. And yes to the summary!