Open · carlosgmartin opened this issue 5 days ago
Hello @carlosgmartin,

You can simply make them support extra args using `with_extra_args_support`:
```python
import optax
from jax import numpy as jnp

params = jnp.zeros(5)
grads = jnp.ones_like(params)
value = 0.0

for opt in [
    optax.sgd(1e-3),
    optax.adam(1e-3),
    optax.adabelief(1e-3),
    optax.contrib.dadapt_adamw(),
    optax.contrib.prodigy(),
]:
    # Wrap the optimizer so its update accepts extra keyword arguments.
    opt = optax.with_extra_args_support(opt)
    opt_state = opt.init(params)
    updates, opt_state = opt.update(grads, opt_state, params, value=value)
```
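Extra keyword arguments are simply ignored by transformations that have no use for them; to complete a step you would then apply the updates as usual with `params = optax.apply_updates(params, updates)`.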
@vroulet Thanks! Just out of curiosity, from a design POV, what's the reason for having the `with_extra_args_support` wrapper rather than just letting all optimizers receive extra args by default? That would eliminate the need for a `GradientTransformationExtraArgs` separate from `GradientTransformation`.
I believe it was for backward compatibility. I fully agree that ideally the gradient transformation API should be

```python
def init(grads_like, **extra_args):
    ...

def update(grads, state, **extra_args):
    ...
```

I don't know if a revamp of the API is possible at this stage, unfortunately.
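For what it's worth, a backward-compatible bridge is possible without touching any existing signature, which is roughly the role `with_extra_args_support` plays today. The sketch below only illustrates that idea; the helper name `ignore_unknown_kwargs` and its `inspect`-based filtering are hypothetical, not optax's actual implementation:

```python
import inspect

import optax


def ignore_unknown_kwargs(opt: optax.GradientTransformation) -> optax.GradientTransformationExtraArgs:
    """Hypothetical sketch: forward only the keyword args the wrapped update declares.

    This illustrates the backward-compatible approach; it is not the actual
    implementation of optax.with_extra_args_support.
    """
    sig = inspect.signature(opt.update)
    takes_var_kwargs = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
    )
    declared = set(sig.parameters)

    def update(updates, state, params=None, **extra_args):
        if not takes_var_kwargs:
            # Silently drop extra args the wrapped transformation does not accept.
            extra_args = {k: v for k, v in extra_args.items() if k in declared}
        return opt.update(updates, state, params, **extra_args)

    return optax.GradientTransformationExtraArgs(opt.init, update)
```

With such a wrapper, callers get a uniform `update(grads, state, params, value=...)` signature while every existing transformation keeps its current one.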
All `update` methods can receive a `params` argument. It is `None` by default.

Most `update` methods can receive a `value` argument, but not all. This is inconvenient because it makes the interface non-uniform and requires one to call `update` in different ways according to the optimizer, which makes code more complex.

Feature request: Allow all `update` methods to receive a `value` argument. It can be `None` by default.

I can submit a PR editing `dadapt_adamw` and `prodigy` accordingly.
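For concreteness, here is a minimal sketch of the uniform call this would enable. The helper name `train_step` is hypothetical, and today the wrapping with `optax.with_extra_args_support` (as in the reply above) is needed for the call to go through:

```python
import optax
from jax import numpy as jnp


def train_step(opt, opt_state, params, grads, value=None):
    # One code path for every optimizer: value is always forwarded and simply
    # defaults to None when the optimizer has no use for it.
    updates, opt_state = opt.update(grads, opt_state, params, value=value)
    return optax.apply_updates(params, updates), opt_state


params = jnp.zeros(5)
grads = jnp.ones_like(params)

for opt in [optax.adam(1e-3), optax.contrib.prodigy()]:
    # With the requested change, this wrapping would become unnecessary.
    opt = optax.with_extra_args_support(opt)
    opt_state = opt.init(params)
    params, opt_state = train_step(opt, opt_state, params, grads, value=0.0)
```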