Hi @pharringtonp19,
Thank you for reporting this and giving evosax a try! As far as I can tell, the error results from the interaction of `jit` with the shape-dependent definition of the default parameters. CMA-ES uses fairly intricate heuristics to set these, and they depend on the dimensionality of your problem and the selected population size.
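The failure mode can be reproduced without evosax. In the sketch below, `default_weights` is a hypothetical stand-in for a shape-dependent default (it is not part of evosax). When `popsize` arrives as a traced argument of a jitted function, the array length is no longer a concrete Python integer and JAX raises an error. Computing the defaults eagerly outside `jit`, or marking the shape-determining argument as static, avoids it.

```python
from functools import partial

import jax
import jax.numpy as jnp


def default_weights(popsize):
    # Hypothetical shape-dependent default: popsize decides the array length.
    return jnp.ones(popsize) / popsize


@jax.jit
def broken(popsize):
    # popsize is a tracer here, so jnp.ones cannot build a concrete shape.
    return default_weights(popsize)


# Option 1: compute the defaults once, eagerly, and close over them.
weights = default_weights(20)


@jax.jit
def weighted_sum(x):
    return jnp.dot(weights, x)


# Option 2: mark the shape-determining argument as static (recompiles per value).
@partial(jax.jit, static_argnums=0)
def weights_static(popsize):
    return default_weights(popsize)


try:
    broken(20)
except Exception as err:  # a TypeError-family error; the exact class varies by JAX version
    print("jit failed:", type(err).__name__)

print(weighted_sum(jnp.arange(20.0)))  # mean of 0..19, roughly 9.5
print(weights_static(4))
```

Option 1 mirrors the workaround of creating the strategy and its parameters outside the jitted function; option 2 trades flexibility for one compilation per distinct value.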
You can solve this by instantiating the strategy and its hyperparameters outside of your jitted function, e.g. via:
```python
from functools import partial

import jax
import jax.numpy as jnp
from evosax import CMA_ES

# Create the strategy and its shape-dependent default parameters eagerly,
# outside of any jitted function.
strategy = CMA_ES(popsize=20, num_dims=1, elite_ratio=0.5)
es_params = strategy.default_params


@partial(jax.jit, static_argnums=(1,))  # n_epochs sets the scan length, so it must be static
def train(key_num, n_epochs):
    rng, init_state_rng = jax.random.split(jax.random.PRNGKey(key_num))
    state = strategy.initialize(init_state_rng, es_params)

    def update(carry, t):
        rng, state = carry
        rng, rng_gen = jax.random.split(rng, 2)
        x, state = strategy.ask(rng_gen, state, es_params)
        fitness = jax.vmap(f)(x)  # f: your fitness function, defined elsewhere
        state = strategy.tell(jnp.expand_dims(x, 1), fitness, state, es_params)
        return (rng, state), (jnp.min(fitness), jnp.mean(fitness), state["best_member"])

    (_, state), results = jax.lax.scan(update, (rng, state), None, length=n_epochs)
    return state, results


final_state, results = train(0, 1000)
```
Luckily, this issue currently only appears for CMA-ES. I am not sure whether there is a better way to handle this and am open to proposals.
Again, thank you 🤗 Rob
@RobertTLange Thanks for the help!
I was looking to compare CMA_ES with Differential_ES in this notebook.
When I run CMA_ES using a `jax.lax.scan` function as my training loop, I get the following -->

I don't get this error when I run CMA_ES in a standard `for` loop.
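A likely reason the two loops behave differently: a Python `for` loop runs each step eagerly with concrete values, whereas `jax.lax.scan` traces the loop body once and compiles it, so the body must keep a fixed carry structure and cannot derive array shapes from traced values. The sketch below uses a hypothetical toy `step` update (not the evosax API) to show the same computation written both ways.

```python
import jax
import jax.numpy as jnp


def step(mean, key):
    # Hypothetical toy update: sample a fixed-size population (8) around `mean`,
    # then move the mean halfway toward the best sample. All shapes are fixed.
    samples = mean + jax.random.normal(key, (8,))
    fitness = samples ** 2  # minimise x^2
    best = samples[jnp.argmin(fitness)]
    return 0.5 * mean + 0.5 * best


keys = jax.random.split(jax.random.PRNGKey(0), 50)

# Eager Python loop: every call sees concrete arrays.
mean_loop = jnp.float32(3.0)
for k in keys:
    mean_loop = step(mean_loop, k)


# lax.scan: the body is traced once; the carry must keep the same shape and dtype.
def body(mean, k):
    new_mean = step(mean, k)
    return new_mean, new_mean  # (next carry, per-step output)


mean_scan, trajectory = jax.lax.scan(body, jnp.float32(3.0), keys)
print(mean_loop, mean_scan, trajectory.shape)
```

Both versions compute the same trajectory; the difference is that anything shape-related has to be fixed before `scan` traces the body.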