google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

multi_transform is inconsistent with a custom optimizer #1109

Open idnm opened 1 week ago

idnm commented 1 week ago

Hi! New to optax.

I wanted to implement the extra-gradient method (see e.g. here https://arxiv.org/abs/1901.08511v2), which is described mathematically by $x_{k+1/2} = x_k - \eta \nabla f(x_k), \quad x_{k+1} = x_k - \eta \nabla f(x_{k+1/2})$.

I'm not sure how to properly account for the midpoint step; here's how I did it:

    def extra_gradient_update(grads, params):

        # Midpoint
        mid_updates = jax.tree.map(lambda g: - learning_rate * g, grads)
        mid_params = optax.apply_updates(params, mid_updates)

        # Grads at midpoint
        mid_grads = jax.grad(func)(mid_params)

        # Final updates
        updates = jax.tree.map(lambda g: - learning_rate * g, mid_grads)

        return updates

The optimizer based on that update function works fine, but for some reason fails as part of optax.multi_transform. Here is the full example attempting to perform a single update step for the function $f = xy$.

    import jax
    import jax.numpy as jnp
    import optax

    def extra_gradient_optimizer(func: callable, learning_rate: float):

        def extra_gradient_update(grads, params):

            # Midpoint
            mid_updates = jax.tree.map(lambda g: - learning_rate * g, grads)
            mid_params = optax.apply_updates(params, mid_updates)

            # Grads at midpoint
            mid_grads = jax.grad(func)(mid_params)

            # Final updates
            updates = jax.tree.map(lambda g: - learning_rate * g, mid_grads)

            return updates

        return optax.stateless(extra_gradient_update)

    def f(params):
        return params['x'] * params['y']

    params = {
        'x': jnp.array(1.0),
        'y': jnp.array(2.0)
    }

    opt = optax.multi_transform(
        {'xopt': extra_gradient_optimizer(f, 0.01),
         'yopt': extra_gradient_optimizer(f, -0.01)},
        {'x': 'xopt',
         'y': 'yopt'})

    state = opt.init(params)
    grads = jax.grad(f)(params)
    opt.update(grads, state, params)

This results in an error that is traced back to the computation of $f$ itself: TypeError: Only integer scalar arrays can be converted to a scalar index. If instead of multi_transform I simply use opt = extra_gradient_optimizer(f, 0.01), the update works fine.

Is this a bug, or am I not doing this the right way?

vroulet commented 3 days ago

Hello @idnm,

GradientTransformations are generally not well suited to computing the gradient inside them (as the name suggests, they are transformations of gradients, not optimization oracles).

So here you can:

  1. create a multi_transform optimizer as you did,
  2. chain it with a transform that keeps a step counter and a copy of the params. When the counter is, e.g., odd, the transform makes the step from the previously stored params (a sketch follows this list).
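For concreteness, here is a minimal sketch of such a stateful transform. The names `AlternatingState` and `step_from_saved_params_every_other_step` are made up for illustration and are not part of the optax API: on even steps it passes the incoming updates through and saves the current params; on odd steps it re-bases the incoming update onto the saved params. The gradient is always computed outside the transform, so it composes with multi_transform.

    from typing import Any, NamedTuple

    import jax
    import jax.numpy as jnp
    import optax

    class AlternatingState(NamedTuple):
        count: jnp.ndarray  # step counter
        saved_params: Any   # params stored before the extrapolation half-step

    def step_from_saved_params_every_other_step() -> optax.GradientTransformation:

        def init_fn(params):
            return AlternatingState(count=jnp.zeros([], jnp.int32),
                                    saved_params=params)

        def update_fn(updates, state, params):
            is_even = (state.count % 2) == 0

            # Even half-step (extrapolation): pass the updates through unchanged,
            # so apply_updates moves the current params to the midpoint.
            # Odd half-step (correction): re-base the update onto the params
            # saved before the extrapolation, expressed as a delta from the
            # current (midpoint) params so that apply_updates lands on
            # x_{k+1} = x_k - eta * grad f(x_{k+1/2}).
            new_updates = jax.tree.map(
                lambda u, p, saved: jnp.where(is_even, u, (saved + u) - p),
                updates, params, state.saved_params)

            # Remember the pre-extrapolation params on even half-steps only.
            new_saved = jax.tree.map(
                lambda p, saved: jnp.where(is_even, p, saved),
                params, state.saved_params)

            return new_updates, AlternatingState(count=state.count + 1,
                                                 saved_params=new_saved)

        return optax.GradientTransformation(init_fn, update_fn)
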
idnm commented 1 day ago

@vroulet Understood! And thanks for the suggestion, using the parameters from one step back seems to do the trick.
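
For reference, a usage sketch under those assumptions: it chains the hypothetical step_from_saved_params_every_other_step from the sketch above with a multi_transform of plain SGD steps, reusing f and params from the example above. Every optimizer step is one half-step of the extra-gradient method, with the gradient recomputed outside the optimizer.

    opt = optax.chain(
        optax.multi_transform(
            {'xopt': optax.sgd(0.01),    # descent on x
             'yopt': optax.sgd(-0.01)},  # ascent on y
            {'x': 'xopt', 'y': 'yopt'}),
        step_from_saved_params_every_other_step())

    params = {'x': jnp.array(1.0), 'y': jnp.array(2.0)}
    state = opt.init(params)

    for _ in range(4):  # two full extra-gradient iterations (two half-steps each)
        grads = jax.grad(f)(params)  # evaluated at x_k, then at the midpoint
        updates, state = opt.update(grads, state, params)
        params = optax.apply_updates(params, updates)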