Open idnm opened 1 week ago
Hello @idnm,
`GradientTransformation`s are generally not well suited to computing the gradient inside them (as the name suggests, a `GradientTransformation` is a transformation of gradients, not an optimization oracle).
So here you can
@vroulet Understood! And thanks for the suggestion; using the parameters from one step back seems to do the trick.
Hi! New to optax.
I wanted to implement the extra-gradient method (see e.g. https://arxiv.org/abs/1901.08511v2), which is described mathematically by $x_{k+1/2} = x_{k} - \eta \nabla f(x_{k}), \quad x_{k+1} = x_{k} - \eta \nabla f(x_{k+1/2})$.
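To illustrate the two half-steps numerically, here is a minimal NumPy sketch with a hypothetical quadratic $f(x) = \tfrac{1}{2} x^2$ (so $\nabla f(x) = x$); the function choice is mine, purely for illustration:

```python
import numpy as np

eta = 0.1
x = np.array(1.0)
grad_f = lambda x: x  # gradient of f(x) = x**2 / 2

# Midpoint step: x_{k+1/2} = x_k - eta * grad f(x_k)
x_mid = x - eta * grad_f(x)

# Update step: x_{k+1} = x_k - eta * grad f(x_{k+1/2}),
# i.e. the midpoint gradient is applied to the original x_k.
x_next = x - eta * grad_f(x_mid)
```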
I'm not sure how to properly account for the midpoint step, here's how I did it.
The optimizer based on that update function works fine, but for some reason fails as part of `optax.multi_transform`. Here is the full example, attempting to perform a single update step for the function $f = x y$. This results in an error that is traced back to the computation of $f$ itself:

`TypeError: Only integer scalar arrays can be converted to a scalar index.`
If instead of the `multi_transform` I simply use `opt = extra_gradient_optimizer(f, 0.01)`, the update works fine. Is this a bug, or am I not doing this the right way?