google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Add Optimistic Adam #1081

Open carlosgmartin opened 2 days ago

carlosgmartin commented 2 days ago

Feature request: add Optimistic Adam, an optimistic variant of Adam introduced in [1]. Among other things, it addresses the limit-cycling behavior that arises in GAN training.
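To make the limit-cycling claim concrete, here is a dependency-free sketch on the bilinear game f(x, y) = x·y, where x minimizes, y maximizes, and the only equilibrium is the origin. It uses plain optimistic gradient descent-ascent (step along twice the current gradient minus the previous one), not Adam. The function `run_bilinear` and its defaults are illustrative choices, not anything from optax.

```python
import math


def run_bilinear(eta=0.1, steps=1000, optimistic=False):
    """Simultaneous gradient descent-ascent on f(x, y) = x * y.

    x minimizes f, y maximizes f; the unique equilibrium is (0, 0).
    Returns the final distance to the origin.
    """
    x, y = 1.0, 2.0
    gx_prev, gy_prev = y, x  # df/dx = y, df/dy = x at the starting point
    for _ in range(steps):
        gx, gy = y, x  # current partial derivatives
        if optimistic:
            # Optimistic step: 2 * current gradient - previous gradient.
            dx, dy = 2 * gx - gx_prev, 2 * gy - gy_prev
        else:
            dx, dy = gx, gy
        x, y = x - eta * dx, y + eta * dy  # descent for x, ascent for y
        gx_prev, gy_prev = gx, gy
    return math.hypot(x, y)
```

For the plain update, each step multiplies the distance to the origin by exactly sqrt(1 + eta²), so the iterates spiral outward; the optimistic correction damps that rotation for small `eta` and the iterates move toward the origin.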

Perhaps it can be implemented by combining `scale_by_adam` with `scale_by_optimistic_gradient` using `chain`.
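A scalar sketch of what that `chain` computes, written in plain Python rather than optax so the arithmetic is visible. It assumes `scale_by_adam`'s defaults (b1=0.9, b2=0.999, eps=1e-8), that the optimistic stage applies `alpha * d_t + beta * (d_t - d_{t-1})` to the Adam direction `d_t`, and that the previous direction starts at zero. `optimistic_adam_steps` is a hypothetical helper, not an optax function.

```python
import math


def optimistic_adam_steps(grads, learning_rate=1e-3, strength=1e-1,
                          b1=0.9, b2=0.999, eps=1e-8):
    """Scalar sketch of Adam normalization followed by an optimistic step.

    Returns the updates (to be added to the parameter) produced for a
    sequence of scalar gradients.
    """
    mu = nu = 0.0   # Adam first/second moment estimates
    prev_dir = 0.0  # previous Adam direction, assumed to start at zero
    alpha, beta = -learning_rate, -strength
    updates = []
    for t, g in enumerate(grads, start=1):
        # Stage 1: Adam-style normalization with bias correction.
        mu = b1 * mu + (1 - b1) * g
        nu = b2 * nu + (1 - b2) * g * g
        mu_hat = mu / (1 - b1 ** t)
        nu_hat = nu / (1 - b2 ** t)
        direction = mu_hat / (math.sqrt(nu_hat) + eps)
        # Stage 2: optimistic combination of current and previous direction.
        updates.append(alpha * direction + beta * (direction - prev_dir))
        prev_dir = direction
    return updates
```

With a constant gradient the Adam direction stays at about 1, so only the first step (where the previous direction is zero) gets the extra `strength` kick; afterwards the update is the usual `-learning_rate` Adam step. Setting `strength == learning_rate` gives `-2 * lr * d_t + lr * d_{t-1}`, i.e. twice the current direction minus the previous one.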

References:

  1. Constantinos Daskalakis, Andrew Ilyas, Vasilis Syrgkanis, Haoyang Zeng. Training GANs with Optimism. ICLR 2018. OpenReview. ArXiv.

carlosgmartin commented 2 days ago

Below is a demonstration:

```python
import argparse

import jax
import optax
from jax import lax, numpy as jnp
from matplotlib import pyplot as plt, rcParams


def optimistic_sgd(learning_rate, strength):
    return optax.scale_by_optimistic_gradient(-learning_rate, -strength)


def optimistic_adam(learning_rate, strength):
    return optax.chain(
        optax.scale_by_adam(),
        optax.scale_by_optimistic_gradient(-learning_rate, -strength),
    )


def optimistic_adam_wrong_order(learning_rate, strength):
    # Same transformations in the opposite order; defined for comparison, not used below.
    return optax.chain(
        optax.scale_by_optimistic_gradient(-learning_rate, -strength),
        optax.scale_by_adam(),
    )


def bilinear_utility_fn(params):
    """Bilinear saddle point. Has a unique Nash equilibrium at the origin."""
    x, y = params
    z = x * y
    return jnp.stack([z, -z])


def dirac_gan_utility_fn(params):
    """Dirac GAN: https://arxiv.org/abs/1801.04406. Has a unique Nash equilibrium at the origin."""
    x, y = params
    z = jnp.logaddexp(0, x * y)
    return jnp.stack([z, -z])


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--game", type=str, default="bilinear")
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--iters", type=int, default=10**5)
    p.add_argument("--strength", type=float, default=1e-1)
    return p.parse_args()


def main():
    args = parse_args()
    match args.game:
        case "bilinear":
            utility_fn = bilinear_utility_fn
        case "dirac_gan":
            utility_fn = dirac_gan_utility_fn
        case _:
            raise NotImplementedError(args.game)

    def update(state, _):
        params, opt_state = state
        jac = jax.jacobian(utility_fn)(params)
        # grads[i] = d utility_i / d params[i]: each player's own partial derivative.
        grads = jax.tree.map(jnp.diag, jac)
        # optax steps against these gradients, so each player descends its own entry.
        updates, opt_state = opt.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return (params, opt_state), params

    _, ax_distances = plt.subplots()
    _, ax_params = plt.subplots()
    params = jnp.array([1.0, 2.0])
    for label, opt in [
        ("SGD", optax.sgd(args.lr)),
        ("Adam", optax.adam(args.lr)),
        ("Optimistic SGD", optimistic_sgd(args.lr, args.strength)),
        ("Optimistic Adam", optimistic_adam(args.lr, args.strength)),
    ]:
        opt_state = opt.init(params)
        _, params_hist = lax.scan(
            update, (params, opt_state), length=args.iters
        )
        distances_to_origin = jnp.hypot(*params_hist.T)
        ax_params.plot(*params_hist.T, label=label, lw=1)
        ax_distances.plot(distances_to_origin, label=label, lw=1)

    ax_params.legend()
    ax_distances.legend()
    ax_params.set(title="parameters")
    ax_distances.set(xlabel="iteration", ylabel="distance to origin")
    rcParams["savefig.dpi"] = 300
    plt.show()


if __name__ == "__main__":
    main()
```
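One reason the order inside `chain` matters, shown as a plain-Python check: bias-corrected Adam normalization is, up to `eps`, invariant to the scale of its input (a negative factor only flips the sign). With `scale_by_adam` last, as in `optimistic_adam_wrong_order`, the `-learning_rate` factor is divided out, so the step size no longer depends on the learning rate and only the ratio `strength / learning_rate` survives. The helper `adam_directions` is a hypothetical reimplementation assuming optax's default `b1`, `b2`, `eps`, not optax code.

```python
import math


def adam_directions(grads, b1=0.9, b2=0.999, eps=1e-8):
    """Bias-corrected Adam directions m_hat / (sqrt(v_hat) + eps) for scalar inputs."""
    mu = nu = 0.0
    out = []
    for t, g in enumerate(grads, start=1):
        mu = b1 * mu + (1 - b1) * g
        nu = b2 * nu + (1 - b2) * g * g
        mu_hat = mu / (1 - b1 ** t)
        nu_hat = nu / (1 - b2 ** t)
        out.append(mu_hat / (math.sqrt(nu_hat) + eps))
    return out


grads = [0.5, -1.0, 2.0, 0.25]
lr = 1e-3
plain = adam_directions(grads)
# Inputs pre-scaled by -lr (as when scaling happens before Adam): the output
# only flips sign, and its magnitude matches up to the eps term.
scaled = adam_directions([-lr * g for g in grads])
```

The same argument covers the optimistic stage: `-lr * g - strength * (g - g_prev)` equals `-lr * (g + (strength / lr) * (g - g_prev))`, so Adam removes the common `-lr` factor.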

Outputs for `--game=bilinear`:

Outputs for `--game=dirac_gan`:

I can submit a PR to create an `optax.optimistic_adam` function.

fabianp commented 2 days ago

This is great, @carlosgmartin!

Would you be willing to contribute such an example to the example gallery (https://optax.readthedocs.io/en/latest/gallery.html)? I think it would be very valuable even though there's the somewhat related https://optax.readthedocs.io/en/latest/_collections/examples/ogda_example.html; the two examples could complement each other. What do you think?

I would also be OK with adding the solver `optimistic_adam` to optax (although that would require a bit of work on the docstring and tests for this solver).

carlosgmartin commented 19 hours ago

@fabianp Done: #1089.