google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Allow all optimizer `update` methods to receive an optional `value` argument #1131

Open carlosgmartin opened 5 days ago

carlosgmartin commented 5 days ago

All update methods can receive a params argument. It is None by default.

Most update methods can receive a value argument, but not all:

import optax
from jax import numpy as jnp

params = jnp.zeros(5)
grads = jnp.ones_like(params)
value = 0.0

for opt in [
    optax.sgd(1e-3),
    optax.adam(1e-3),
    optax.adabelief(1e-3),
    optax.contrib.dadapt_adamw(),
    optax.contrib.prodigy(),
]:
    opt_state = opt.init(params)
    try:
        opt.update(grads, opt_state, params, value=value)
    except Exception as e:
        print(e)

Output:

dadapt_adamw.<locals>.update_fn() got an unexpected keyword argument 'value'
prodigy.<locals>.update_fn() got an unexpected keyword argument 'value'

This is inconvenient because it makes the interface non-uniform and requires one to call update in different ways according to the optimizer, which makes code more complex.
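To illustrate the extra complexity, here is a minimal sketch of the kind of caller-side workaround a non-uniform interface pushes people toward. It uses only the standard library; `accepts_keyword`, `call_update`, and the two toy update functions are hypothetical stand-ins, not part of optax.

```python
import inspect
from typing import Any, Callable


def accepts_keyword(fn: Callable[..., Any], name: str) -> bool:
    """Return True if `fn` can be called with the keyword argument `name`."""
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    for p in inspect.signature(fn).parameters.values():
        if p.kind is inspect.Parameter.VAR_KEYWORD:
            return True  # fn takes **kwargs, so any keyword is accepted.
        if p.name == name and p.kind in keyword_kinds:
            return True
    return False


def call_update(update_fn, grads, state, params=None, **extra_args):
    """Call `update_fn`, forwarding only the extra keywords it accepts."""
    accepted = {
        k: v for k, v in extra_args.items() if accepts_keyword(update_fn, k)
    }
    return update_fn(grads, state, params, **accepted)


# Hypothetical stand-ins for two optimizers with different update signatures.
def update_with_value(grads, state, params=None, *, value=None):
    del params, value
    return [-0.1 * g for g in grads], state


def update_without_value(grads, state, params=None):
    del params
    return [-0.1 * g for g in grads], state


grads = [1.0, 1.0]
for update_fn in (update_with_value, update_without_value):
    # The same call now works for both, at the cost of signature inspection.
    updates, _ = call_update(update_fn, grads, (), None, value=0.0)
```

With a uniform `update` signature, none of this inspection would be needed.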

Feature request: Allow all update methods to receive a value argument. It can be None by default.

I can submit a PR editing dadapt_adamw and prodigy accordingly.

vroulet commented 3 days ago

Hello @carlosgmartin,

You can simply make them support extra args by wrapping them with optax.with_extra_args_support.

import optax
from jax import numpy as jnp

params = jnp.zeros(5)
grads = jnp.ones_like(params)
value = 0.0

for opt in [
    optax.sgd(1e-3),
    optax.adam(1e-3),
    optax.adabelief(1e-3),
    optax.contrib.dadapt_adamw(),
    optax.contrib.prodigy(),
]:
    # Wrap the optimizer so that update accepts extra keyword arguments such as value.
    opt = optax.with_extra_args_support(opt)
    opt_state = opt.init(params)
    opt.update(grads, opt_state, params, value=value)

carlosgmartin commented 1 day ago

@vroulet Thanks! Just out of curiosity, from a design POV, what's the reason for having the with_extra_args_support wrapper, rather than just letting all optimizers receive extra args by default? That would eliminate the need to have a GradientTransformationExtraArgs separate from GradientTransformation.
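For reference, the general wrapper pattern under discussion can be sketched in plain Python. This is an illustration of the idea only, not optax's implementation: `Transformation`, `ignore_extra_args`, and `toy_sgd` are hypothetical names, and plain lists stand in for JAX arrays.

```python
from typing import Any, Callable, NamedTuple


class Transformation(NamedTuple):
    """Hypothetical stand-in for a pair of init/update functions."""
    init: Callable[..., Any]
    update: Callable[..., Any]


def ignore_extra_args(tx: Transformation) -> Transformation:
    """Wrap `tx` so its update accepts, and discards, extra keyword arguments."""

    def update(grads, state, params=None, **extra_args):
        del extra_args  # The wrapped update never sees these.
        return tx.update(grads, state, params)

    return Transformation(tx.init, update)


def toy_sgd(learning_rate: float) -> Transformation:
    """A toy optimizer whose update has the strict three-argument signature."""

    def init(params):
        del params
        return ()

    def update(grads, state, params=None):
        del params
        return [-learning_rate * g for g in grads], state

    return Transformation(init, update)


opt = ignore_extra_args(toy_sgd(0.5))
state = opt.init([0.0, 0.0])
# Passing value=... is now accepted even though toy_sgd does not use it.
updates, state = opt.update([1.0, 2.0], state, None, value=0.0)
```

The wrapper changes only the accepted signature; the updates themselves are computed exactly as before.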

vroulet commented 1 day ago

I believe it was for backward compatibility. I fully agree that ideally the gradient transformation API should be:

def init(grads_like, **extra_args):
  ...
def update(grads, state, **extra_args):
  ...

Unfortunately, I don't know whether a revamp of the API is possible at this stage.
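As a rough sketch of what such an interface could look like, assuming `params` simply becomes one of the extra keyword arguments: the toy `scale`, `add_decayed_weights`, and `chain` below are hypothetical plain-Python stand-ins written for this illustration, with lists in place of JAX arrays. They are not the optax functions of the same names.

```python
from typing import Any, Callable, NamedTuple


class Transformation(NamedTuple):
    """Hypothetical container for the init/update pair sketched above."""
    init: Callable[..., Any]
    update: Callable[..., Any]


def scale(factor: float) -> Transformation:
    """Toy transformation that needs no extra arguments and ignores them."""

    def init(grads_like, **extra_args):
        del grads_like, extra_args
        return ()

    def update(grads, state, **extra_args):
        del extra_args
        return [factor * g for g in grads], state

    return Transformation(init, update)


def add_decayed_weights(weight_decay: float) -> Transformation:
    """Toy transformation that picks `params` out of the extra arguments."""

    def init(grads_like, **extra_args):
        del grads_like, extra_args
        return ()

    def update(grads, state, *, params, **extra_args):
        del extra_args
        return [g + weight_decay * p for g, p in zip(grads, params)], state

    return Transformation(init, update)


def chain(*transformations: Transformation) -> Transformation:
    """Apply transformations in order, forwarding all extra arguments to each."""

    def init(grads_like, **extra_args):
        return tuple(t.init(grads_like, **extra_args) for t in transformations)

    def update(grads, state, **extra_args):
        new_state = []
        for t, s in zip(transformations, state):
            grads, s = t.update(grads, s, **extra_args)
            new_state.append(s)
        return grads, tuple(new_state)

    return Transformation(init, update)


opt = chain(add_decayed_weights(0.5), scale(-1.0))
params = [2.0, 4.0]
grads = [1.0, 1.0]
state = opt.init(grads)
# Every transformation receives params and value; each uses what it needs.
updates, state = opt.update(grads, state, params=params, value=0.0)
```

In this sketch every transformation accepts arbitrary keyword arguments, so no wrapper or second transformation type is needed.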