google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Help with implementing Adam#Lion optimizer #15294

Open buttercutter opened 1 year ago

buttercutter commented 1 year ago

See the last few code cells of https://colab.research.google.com/drive/1U_qMlcQfD1Dxe-_V9cHpNOA2iF8X1jYg#scrollTo=emIvHzxwQtzj&line=21&uniqifier=1

Adam#Lion optimizer code:

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import optax

def adam_on_lion(x0, y0, lr, max_iter, adam_betas=(0.9, 0.999), lion_betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
    x = jnp.array([x0, y0], dtype=float)
    traj = [x.copy()]

    # Define the Adam and Lion optimizers
    adam = optax.adam(learning_rate=lr, b1=adam_betas[0], b2=adam_betas[1], eps=eps)
    lion = optax.lion(learning_rate=lr, b1=lion_betas[0], b2=lion_betas[1], weight_decay=weight_decay)

    # Define the optimizer chain and update function
    optimizer = optax.chain(lion, adam)
    lion_state, adam_state = lion.init(x), adam.init(x)

    for i in range(max_iter):
        grad = rosenbrock(x) + weight_decay*x

        params = (x[0], x[1])
        #params = jax.tree_util.tree_multimap(lambda x, y: jnp.array([x, y]), *params)
        params = jax.tree_map(lambda x: jnp.array(x), params)
        print("params = ", params)

        # Apply the Lion and Adam steps
        lion_step, lion_state = lion.update(grad, state=lion_state, params=params)
        adam_step, adam_state = adam.update(grad, state=adam_state, params=params)

        # Grafting adam#lion: update direction from lion, update magnitude from adam
        step = jnp.linalg.norm(adam_step) * lion_step / jnp.linalg.norm(lion_step)

        # Update the parameters using the optimizer chain
        x, optimizer_state = optimizer.update(x, grad)
        traj.append(x.copy())

    return jnp.array(traj)

Error traceback:

params =  (DeviceArray(-2., dtype=float32), DeviceArray(2., dtype=float32))

---------------------------------------------------------------------------

TypeError                                 Traceback (most recent call last)

<ipython-input-10-9e9f791ca8e1> in <module>
----> 1 graft_traj = adam_on_lion(x0=-2, y0=2, lr=0.5, max_iter=250, weight_decay=0)
      2 
      3 fig = plt.figure(figsize=(4,4), dpi=300)
      4 fig.patch.set_facecolor("white")
      5 plot_traj(graft_traj, "Adam#Lion (lr=0.5)")

<ipython-input-9-f1e386ebc5f4> in adam_on_lion(x0, y0, lr, max_iter, adam_betas, lion_betas, eps, weight_decay)
     26 
     27         # Apply the Lion and Adam steps
---> 28         lion_step, lion_state = lion.update(grad, state=lion_state, params=params)
     29         adam_step, adam_state = adam.update(grad, state=adam_state, params=params)
     30 

/usr/local/lib/python3.9/dist-packages/optax/_src/combine.py in update_fn(updates, state, params, **extra_args)
     57     new_state = []
     58     for s, fn in zip(state, update_fns):
---> 59       updates, new_s = fn(updates, s, params, **extra_args)
     60       new_state.append(new_s)
     61     return updates, tuple(new_state)

/usr/local/lib/python3.9/dist-packages/optax/_src/base.py in update(***failed resolving arguments***)
    309   def update(updates, state, params=None, **extra_args):
    310     del extra_args
--> 311     return tx.update(updates, state, params)
    312 
    313   return GradientTransformationExtraArgs(tx.init, update)

/usr/local/lib/python3.9/dist-packages/optax/_src/transform.py in update_fn(updates, state, params)
    767     if params is None:
    768       raise ValueError(base.NO_PARAMS_MSG)
--> 769     updates = jax.tree_util.tree_map(
    770         lambda g, p: g + weight_decay * p, updates, params)
    771     return updates, state

/usr/local/lib/python3.9/dist-packages/jax/_src/tree_util.py in tree_map(f, tree, is_leaf, *rest)
    205   leaves, treedef = tree_flatten(tree, is_leaf)
    206   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 207   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    208 
    209 def build_tree(treedef: PyTreeDef, xs: Any) -> Any:

/usr/local/lib/python3.9/dist-packages/jax/_src/tree_util.py in <genexpr>(.0)
    205   leaves, treedef = tree_flatten(tree, is_leaf)
    206   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 207   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
    208 
    209 def build_tree(treedef: PyTreeDef, xs: Any) -> Any:

/usr/local/lib/python3.9/dist-packages/optax/_src/transform.py in <lambda>(g, p)
    768       raise ValueError(base.NO_PARAMS_MSG)
    769     updates = jax.tree_util.tree_map(
--> 770         lambda g, p: g + weight_decay * p, updates, params)
    771     return updates, state
    772 

/usr/local/lib/python3.9/dist-packages/jax/_src/numpy/lax_numpy.py in deferring_binary_op(self, other)
   4936       return binary_op(*args)
   4937     if isinstance(other, _rejected_binop_types):
-> 4938       raise TypeError(f"unsupported operand type(s) for {opchar}: "
   4939                       f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}")
   4940     return NotImplemented

TypeError: unsupported operand type(s) for +: 'DeviceArray' and 'tuple'

terafo commented 1 year ago

I presume the lion.init() function returns a tuple, and then you try to add that tuple to an array, which raises the TypeError.
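
Reading the traceback, the failing + is inside optax's weight-decay transform, lambda g, p: g + weight_decay * p, where the array-valued updates meet the tuple-valued params. A minimal standalone snippet reproducing that shape of failure (the variable names here are hypothetical, not from the notebook):

import jax.numpy as jnp

g = jnp.array([1.0, 2.0])               # updates: a single array leaf, like grad above
p = (jnp.array(-2.0), jnp.array(2.0))   # params: a tuple, as built in the loop
weight_decay = 0
g + weight_decay * p                    # 0 * tuple == (), so this is array + tuple -> TypeError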

buttercutter commented 1 year ago

@terafo how do you suggest modifying the code properly in this case?
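
For reference, one possible restructuring, sketched only as a starting point rather than a verified fix: it assumes a rosenbrock objective (a placeholder definition is included, the Colab notebook defines its own), takes the gradient with jax.grad so that updates and params share the same pytree structure, and applies the grafted step with optax.apply_updates instead of the optax.chain call.

import jax
import jax.numpy as jnp
import optax

def rosenbrock(p):
    # Placeholder objective; substitute the notebook's own definition.
    return (1.0 - p[0]) ** 2 + 100.0 * (p[1] - p[0] ** 2) ** 2

def adam_on_lion(x0, y0, lr, max_iter, adam_betas=(0.9, 0.999), lion_betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
    x = jnp.array([x0, y0], dtype=float)
    traj = [x]

    # optax.lion applies weight decay internally, so there is no manual weight_decay*x term below.
    adam = optax.adam(learning_rate=lr, b1=adam_betas[0], b2=adam_betas[1], eps=eps)
    lion = optax.lion(learning_rate=lr, b1=lion_betas[0], b2=lion_betas[1], weight_decay=weight_decay)
    lion_state, adam_state = lion.init(x), adam.init(x)

    grad_fn = jax.grad(rosenbrock)

    for _ in range(max_iter):
        g = grad_fn(x)  # same pytree structure as x: a single array

        # Both updates see array-valued params, so the weight-decay
        # transform never mixes arrays and tuples.
        lion_step, lion_state = lion.update(g, lion_state, x)
        adam_step, adam_state = adam.update(g, adam_state, x)

        # Grafting Adam#Lion: update direction from Lion, update magnitude from Adam.
        step = jnp.linalg.norm(adam_step) * lion_step / jnp.linalg.norm(lion_step)

        x = optax.apply_updates(x, step)
        traj.append(x)

    return jnp.array(traj)

Called as before, e.g. graft_traj = adam_on_lion(x0=-2, y0=2, lr=0.5, max_iter=250, weight_decay=0).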