RobertTLange / evosax

Evolution Strategies in JAX 🦎

[Feature request] Integration with optax #51

Open · carlosgmartin opened this issue 1 year ago

carlosgmartin commented 1 year ago

optax is the most popular JAX library for optimizers. Feature request: let users pass an optax.GradientTransformation to strategy constructors, rather than a combination of opt_name, lrate_init, etc. (see the sketch at the end of this comment). This has a few advantages:

Thank you for creating this library. Look forward to hearing your thoughts.
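
To make the request concrete, here is a rough sketch of the difference at the call site. The first call mirrors the arguments used in the full example further down in this thread; the optimizer keyword in the second call is hypothetical and not part of the current evosax API.

import evosax
import optax
from jax import numpy as jnp

params = {"a": jnp.ones(4)}

# current interface: the optimizer is selected by name plus loose hyperparameters
es_current = evosax.OpenES(
    popsize=8,
    pholder_params=params,
    opt_name="adam",
    lrate_init=1e-2,
    sigma_init=1e-1,
)

# requested interface (hypothetical `optimizer` keyword): pass any
# optax.GradientTransformation, including chained or scheduled ones
es_requested = evosax.OpenES(
    popsize=8,
    pholder_params=params,
    optimizer=optax.adam(1e-2),
    sigma_init=1e-1,
)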

RobertTLange commented 1 year ago

Hi @carlosgmartin -- thank you for raising this and please excuse the late response! I have actually been thinking about this for a while. For now, we have opted for the arguably awkward re-implementation of the most common optimizers. There were a couple of reasons for this:

In general, I am very open to switching to optax in the long run, especially since I am also personally interested in trying out some of the learned gradient-based optimizers in the context of ES (e.g. LION or VeLO), all of which already have optax support. Furthermore, it may make population-parallelism (similar to data-parallelism in multi-device settings) a lot smoother. I have started playing around with some of these ideas for OpenAI-ES here.
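
As a toy illustration of the population-parallelism point (everything below is illustrative and not taken from the linked OpenAI-ES experiments): the population can be sharded across devices and evaluated with a vmap inside a pmap, exactly as one would shard a data batch.

import jax
from jax import numpy as jnp

def fitness(params):
    # assumed toy fitness: squared norm of the candidate
    return jnp.sum(params ** 2)

n_devices = jax.local_device_count()
popsize, dim = 8 * n_devices, 4
population = jnp.ones((popsize, dim))

# shard the population over devices, then evaluate each device's slice in parallel
sharded = population.reshape(n_devices, popsize // n_devices, dim)
fitnesses = jax.pmap(jax.vmap(fitness))(sharded).reshape(popsize)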

Let me know if you would be interested in supporting me in this endeavor. My bucket list right now is pretty large -- so at the moment I can't guarantee that this will happen within the next week(s). Cheers and again thank you, Rob

carlosgmartin commented 1 year ago

@RobertTLange Perhaps a good first step could be to reimplement https://github.com/RobertTLange/evosax/blob/main/evosax/core/optimizer.py internally in terms of optax (perhaps via an OptaxWrapper class), and gradually change the API "outwards" from there. What do you think?
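
A minimal sketch of what such an OptaxWrapper could look like; the initialize/step method names below are assumptions about the internal optimizer interface rather than the actual signatures in evosax/core/optimizer.py:

import optax

class OptaxWrapper:
    """Adapter exposing an optax.GradientTransformation through an
    evosax-style optimizer interface (method names assumed, not verbatim)."""

    def __init__(self, optimizer: optax.GradientTransformation):
        self.optimizer = optimizer

    def initialize(self, params):
        # the optax optimizer state would live inside the ES strategy state
        return self.optimizer.init(params)

    def step(self, params, grads, opt_state):
        updates, opt_state = self.optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state

Internally, the existing string arguments could then map to something like OptaxWrapper(optax.adam(lrate_init)), leaving the public strategy API untouched for a first iteration.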

RobertTLange commented 1 year ago

That indeed sounds like a great start ;)

The question is how we ultimately want to pass the optax optimizer to the finite-difference-style strategies (OpenAI-ES, PGPE, ARS, ASEBO, etc.). Should this be a string, as is done right now? Or should we directly pass an optimizer gradient transform fn at strategy initialization? Maybe we can support both to start?
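
A sketch of the "support both" option, using a hypothetical resolve_optimizer helper; the name-to-constructor mapping is illustrative, not the full set evosax currently supports:

from typing import Union

import optax

def resolve_optimizer(
    opt: Union[str, optax.GradientTransformation],
    lrate_init: float = 1e-2,
) -> optax.GradientTransformation:
    # keep backwards compatibility with the current string-based spec
    if isinstance(opt, str):
        name_to_ctor = {"sgd": optax.sgd, "adam": optax.adam, "adamw": optax.adamw}
        return name_to_ctor[opt](lrate_init)
    # otherwise assume it is already an optax.GradientTransformation
    return opt

# both spellings yield the same kind of object for the strategy to consume
opt_from_string = resolve_optimizer("adam", lrate_init=1e-2)
opt_direct = resolve_optimizer(optax.adam(1e-2))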

carlosgmartin commented 1 year ago

It seems to me that, at least in the long run, the optax optimizer should be passed in directly. I see a few advantages to this approach:

What do you think?

carlosgmartin commented 1 year ago

@RobertTLange Here's an example of the simplified interface and implementation I have in mind:

import argparse
import sys

import evosax
import jax
import optax
from jax import lax, numpy as jnp, random
from jax.flatten_util import ravel_pytree

# Proposed interface: the strategy receives an optax optimizer directly.
class OpenAIES:
    def __init__(self, loss_fn, scale, batch_size, optimizer):
        assert batch_size % 2 == 0, "batch_size must be even"
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.pairs = batch_size // 2
        self.scale = scale

    def init(self, params, key):
        # ES state is simply (mean parameters, optax optimizer state)
        return params, self.optimizer.init(params)

    def grads(self, params, key):
        # Antithetic ES gradient estimate: evaluate the loss at mirrored
        # perturbations params ± scale * z and average the central
        # finite differences over all sampled directions.
        x, unravel = ravel_pytree(params)
        keys = random.split(key, 1 + self.pairs)
        z = random.normal(keys[0], [self.pairs, x.size])
        u = z * self.scale
        yp = jax.vmap(self.loss_fn)(jax.vmap(unravel)(x + u), keys[1:])
        ym = jax.vmap(self.loss_fn)(jax.vmap(unravel)(x - u), keys[1:])
        g = (yp - ym) @ z / (len(z) * 2 * self.scale)
        grads = unravel(g)
        return grads

    def step(self, state, key):
        params, opt_state = state
        grads = self.grads(params, key)
        updates, opt_state = self.optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return (params, opt_state), None

    def params(self, state):
        params, opt_state = state
        return params

# Current interface: wrap an evosax strategy configured via opt_name / lrate_init.
class EvosaxWrapper:
    def __init__(self, loss_fn, strategy):
        self.loss_fn = loss_fn
        self.strategy = strategy

    def init(self, params, key):
        return self.strategy.initialize(key, init_mean=params)

    def step(self, state, key):
        keys = jax.random.split(key, 1 + self.strategy.popsize)
        x, state = self.strategy.ask(keys[0], state)
        # jnp.concatenate([keys[1:], keys[1:]])
        y = jax.vmap(self.loss_fn)(x, keys[1:])
        n_state = self.strategy.tell(x, y.astype(float), state)
        return n_state, None

    def params(self, state):
        return self.strategy.param_reshaper.reshape_single(state.mean)

def parse_args(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--scale", type=float, default=1e-1)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=1e-2)
    parser.add_argument("--steps", type=int, default=10**4)
    return parser.parse_args(argv)

def main(argv):
    args = parse_args(argv)

    def loss_fn(params, key):
        return params["a"] @ params["a"] + random.normal(key) * 0.1

    params = {"a": jnp.ones(4)}
    key = random.PRNGKey(args.seed)

    for strategy in [
        EvosaxWrapper(
            loss_fn,
            evosax.OpenES(
                popsize=args.batch_size,
                pholder_params=params,
                opt_name="adam",
                lrate_init=args.lr,
                lrate_limit=0,
                sigma_init=args.scale,
                sigma_limit=0,
            ),
        ),
        OpenAIES(loss_fn, args.scale, args.batch_size, optax.adam(args.lr)),
    ]:
        keys = random.split(key, 1 + args.steps)
        state = strategy.init(params, keys[0])
        state, _ = lax.scan(strategy.step, state, keys[1:])
        print(strategy.params(state)["a"])

if __name__ == "__main__":
    main(sys.argv[1:])

Output:

ParameterReshaper: 4 parameters detected for optimization.
[ 0.01049289 -0.00886521 -0.00689528  0.0649689 ]
[-1.65249183e-08  9.41156397e-09  2.47366607e-08  1.30447395e-08]