RobertTLange / evosax

Evolution Strategies in JAX 🦎

PGPE bug? #68

Closed cornelius-braun closed 4 months ago

cornelius-braun commented 4 months ago

Hello,

First, thank you for this very useful and valuable repository!

I noticed that PGPE fails for some combinations of population size and elite ratio. To verify, I reinstalled version 0.1.6 and ran the following minimal example:

import jax
from evosax import PGPE

rng = jax.random.PRNGKey(seed=0)
# Positional arguments: popsize=100, num_dims=200
strategy = PGPE(100, 200,
                elite_ratio=0.1, opt_name="adam", lrate_init=.1, sigma_init=.1)
es_params = strategy.default_params
state = strategy.initialize(rng, es_params)

rng, rng_ask, rng_tell = jax.random.split(rng, 3)
x, state = strategy.ask(rng_ask, state, es_params)
fitness = jax.random.uniform(rng_tell, (x.shape[0],))
state = strategy.tell(x, fitness, state, es_params)

The error message that I receive in this case is:

Traceback (most recent call last):
  File "main.py", line 34, in <module>
    state = strategy.tell(x, fitness, state, es_params)
  File "/.venv/lib/python3.10/site-packages/evosax/strategy.py", line 133, in tell
    state = self.tell_strategy(x, fitness_re, state, params)
  File "/.venv/lib/python3.10/site-packages/evosax/strategies/pgpe.py", line 150, in tell_strategy
    (jnp.expand_dims(all_avg_scores, axis=1) - baseline)
  File "/.venv/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 743, in op
    return getattr(self.aval, f"_{name}")(self, *args)
  File "/.venv/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 271, in deferring_binary_op
    return binary_op(*args)
  File "/.venv/lib/python3.10/site-packages/jax/_src/numpy/ufuncs.py", line 99, in fn
    return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)
TypeError: mul got incompatible shapes for broadcasting: (5, 1), (50, 200).

As far as I can see, the issue arises because jnp.expand_dims(all_avg_scores, axis=1) - baseline has shape (elite_popsize, 1), which is (5, 1) here, while noise_1**2 - jnp.expand_dims(state.sigma**2, axis=0) has shape (popsize / 2, num_dims), which is (50, 200). The two cannot be broadcast against each other when they are multiplied. I have only briefly looked at the paper by Sehnke et al., but could it be that noise_1[elite_idx] should be used instead?
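The shape mismatch described above can be reproduced without evosax. The following is a minimal sketch, not the library's code: the variable names are borrowed from the traceback and the issue text, the way elite_idx and baseline are computed here is an assumption made only for illustration, and indexing noise_1 with elite_idx is the change proposed in this issue, not necessarily the fix that was merged.

```python
import jax
import jax.numpy as jnp

# Sizes taken from the report: popsize=100, num_dims=200, elite_ratio=0.1.
popsize, num_dims = 100, 200
half = popsize // 2          # PGPE samples noise in mirrored pairs: 50 rows
elite_popsize = 5            # matches the (5, 1) shape in the traceback

key = jax.random.PRNGKey(0)
k_noise, k_fit = jax.random.split(key)
sigma = jnp.full((num_dims,), 0.1)
noise_1 = jax.random.normal(k_noise, (half, num_dims)) * sigma
avg_scores = jax.random.uniform(k_fit, (half,))

# Assumption for illustration: elites are the lowest-scoring pairs and the
# baseline is the mean score. Only the resulting shapes matter here.
elite_idx = jnp.argsort(avg_scores)[:elite_popsize]
all_avg_scores = avg_scores[elite_idx]
baseline = jnp.mean(avg_scores)

weights = jnp.expand_dims(all_avg_scores, axis=1) - baseline  # (5, 1)
sq_dev = noise_1**2 - jnp.expand_dims(sigma**2, axis=0)       # (50, 200)

# (5, 1) and (50, 200) are not broadcast-compatible, so this raises.
try:
    weights * sq_dev
    mismatch = False
except (TypeError, ValueError):
    mismatch = True

# Proposed change: restrict the noise to the elite rows before multiplying.
sq_dev_elite = noise_1[elite_idx] ** 2 - jnp.expand_dims(sigma**2, axis=0)
product = weights * sq_dev_elite                               # (5, 200)

print(weights.shape, sq_dev.shape, product.shape, mismatch)
```

With elite_ratio=1.0 the elite rows are all 50 pairs, so both operands have 50 rows and the multiplication broadcasts without error, which is consistent with the bug only showing up for smaller elite ratios.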

RobertTLange commented 4 months ago

Dear @cornelius-braun, thank you so much for raising this. I'm really sorry: I apparently only tested it with elite_ratio = 1.0, which works despite the bug ;) I have fixed it, and you can install the latest version from the GitHub repo using

pip install git+https://github.com/RobertTLange/evosax.git@main

Please excuse the inconvenience! Cheers, Rob

P.S.: The fix will be released in the next version. Thanks again.

cornelius-braun commented 4 months ago

Thank you very much for your quick action! 🙏