google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Add end_scale argument #975

Closed · stefanocortinovis closed this 1 month ago

stefanocortinovis commented 1 month ago

This PR adds an end_scale argument to contrib.reduce_on_plateau().
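
For illustration, a minimal usage sketch (not taken from the PR itself; the other keyword arguments such as factor and patience are assumed from the existing reduce_on_plateau signature, and end_scale is the argument proposed here):

```python
import jax.numpy as jnp
import optax

# Illustrative sketch only: chain an optimizer with reduce_on_plateau so that
# the update scale shrinks by `factor` on each plateau but never drops below
# the proposed `end_scale` bound.
opt = optax.chain(
    optax.sgd(learning_rate=1e-1),
    optax.contrib.reduce_on_plateau(
        factor=0.5,      # multiply the scale by 0.5 on each plateau
        patience=5,      # non-improving updates to wait before reducing
        end_scale=1e-3,  # proposed lower bound on the scale
    ),
)

params = {"w": jnp.zeros(3)}
opt_state = opt.init(params)
```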

fabianp commented 1 month ago

Thanks @stefanocortinovis for the contribution! The code looks good to me but I'm not so sure about the name.

What would you think about end_factor so it references the fact that it's related to the factor keyword?

stefanocortinovis commented 1 month ago

Thanks for the quick feedback, @fabianp! I named it end_scale because it is effectively an upper/lower bound for the scale attribute of the ReduceLROnPlateauState class. Anyway, I'm happy to change it to whatever name you think is best.

fabianp commented 1 month ago

ah, I see now. Thanks for the response, makes sense

fabianp commented 1 month ago

the reason I'm not convinced by end_scale is that scale relates more to the internal implementation than to the other parameters we expose. In other words, it may be confusing for a user just reading the docs to work out what scale refers to.

Anyway, I also couldn't come up with a better name, so let's give this 24h and merge it if nobody comes up with a better suggestion.

vroulet commented 1 month ago

Thanks @stefanocortinovis! Why would the factor be greater than 1? If we allow for a factor greater than 1, we are not "reducing" the lr, and the name of the function may need to be changed. If we agree that the factor should be lower than 1:

stefanocortinovis commented 1 month ago

Thanks for the suggestion, @vroulet!

Yeah, I was also wondering that. Personally, I have only seen reduce_on_plateau() used to do what its name suggests. I opted for allowing scale > 1.0 for consistency with e.g. exponential_decay, which explicitly allows for rate > 1.0. It would probably make sense to keep things simple as you recommend, though.

In principle, I agree that min_lr would be ideal to communicate the meaning of the argument. However, since reduce_on_plateau() is used to scale the parameter updates directly, the scale is not exactly the learning rate applied to them. To make this concrete, standard gradient descent with lr = 1e-1 and end_scale = 1e-3 would imply scale >= 1e-3, and hence an effective lr >= 1e-4. That is, end_scale and the actual minimum for lr might be different. Do you think that would be clear anyway?
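
A toy numeric sketch of that distinction (my own illustration, not from the PR):

```python
lr = 1e-1          # learning rate passed to the inner optimizer, e.g. optax.sgd
end_scale = 1e-3   # proposed lower bound on reduce_on_plateau's scale

# reduce_on_plateau multiplies the optimizer's updates by `scale`, so the
# effective learning rate is lr * scale. Bounding scale below by end_scale
# therefore bounds the effective learning rate below by lr * end_scale.
min_effective_lr = lr * end_scale
print(min_effective_lr)  # ~1e-4, which differs from end_scale itself
```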

vroulet commented 1 month ago

Then min_scale would be perfect, no?

fabianp commented 1 month ago

> Then min_scale would be perfect, no?

+1, I like this

stefanocortinovis commented 1 month ago

Sounds good! I made the change from end_scale to min_scale and added an assertion to check the value of factor.
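
For reference, the kind of check being described might look roughly like this (a sketch under my own assumptions about the bounds and the error message, not the actual diff):

```python
def _validate_factor(factor: float) -> None:
    # Hypothetical sketch of the factor check, assuming the scheduler should
    # only ever reduce the scale, i.e. 0 < factor <= 1.
    if not 0.0 < factor <= 1.0:
        raise ValueError(
            f"reduce_on_plateau expects 0 < factor <= 1, got factor={factor}."
        )
```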