RobertTLange / evosax

Evolution Strategies in JAX 🦎
Apache License 2.0

Evaluating a population of batched environments with CMA-ES #27

Closed nhansendev closed 1 year ago

nhansendev commented 1 year ago

I'm trying to implement a simple Evaluator class that handles rollouts of batched environments for a population of MLP networks optimized with CMA-ES.

Each environment contains a batch of episodes that one network can step through in parallel. Each network is paired with one batched environment, so the result is a population of batched environments to iterate over.

I've tried implementing this using the brax control example as a reference:

self.rollout_repeats = jax.vmap(self.network, in_axes=(0, None))
self.rollout_pop = jax.vmap(self.rollout_repeats, in_axes=(None, map_dict))

Here, map_dict is provided by param_reshaper.vmap_dict.

The function is called once each step as: action = self.rollout_pop(jnp.stack(state), policy_params, rng=rng_net)

The state array provided to self.rollout_pop has the shape (population_qty, environment_batch, features). The intent is that each network processes only its own (environment_batch, features) slice rather than all of the data. This would produce an output of shape (population_qty, environment_batch, action_dim).
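
The intended shape mapping can be sketched in plain JAX, without the evosax or Flax APIs. This is only an illustration under assumptions: the `mlp` function, the parameter names, and the sizes are hypothetical, and `map_dict` is a stand-in that assumes param_reshaper.vmap_dict is a pytree of 0s with the same structure as the parameters.

```python
import jax
import jax.numpy as jnp

def mlp(params, obs):
    """Hypothetical two-layer MLP applied to one batch: (batch, features)."""
    hidden = jnp.tanh(obs @ params["w1"] + params["b1"])
    return hidden @ params["w2"] + params["b2"]

pop, batch, feat, hid, act = 4, 8, 3, 16, 2
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

# Every leaf carries a leading population axis, as reshaped ES samples would.
params = {
    "w1": jax.random.normal(k1, (pop, feat, hid)),
    "b1": jnp.zeros((pop, hid)),
    "w2": jax.random.normal(k2, (pop, hid, act)),
    "b2": jnp.zeros((pop, act)),
}

# Assumed stand-in for param_reshaper.vmap_dict: axis 0 for every leaf.
map_dict = jax.tree_util.tree_map(lambda _: 0, params)

states = jax.random.normal(k3, (pop, batch, feat))

# Map BOTH the parameters and the states over the population axis, so
# member i only sees states[i]. The batch axis is handled by the matmul.
act_pop = jax.vmap(mlp, in_axes=(map_dict, 0))
actions = act_pop(params, states)
print(actions.shape)
```

Because the network already broadcasts over the environment batch, a single vmap over the population axis is enough in this sketch; a second, inner vmap would only be needed if the network accepted a single observation at a time.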

So far I have only received various errors and assertion failures, even when trying to simplify it with a non-batched environment. Please let me know what's wrong with the vmaps. Are they even appropriate for this task? The proper usage of the map_dict is definitely a source of confusion here.
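
Two jax.vmap behaviors may be relevant to the call above, although this is a guess because the error messages are not shown. An in_axes entry of None broadcasts that argument to every mapped call instead of slicing it, and keyword arguments are always mapped over their leading axis. The sketch below passes per-member PRNG keys positionally with an explicit axis; `noisy_policy`, `scales`, and the sizes are hypothetical.

```python
import jax
import jax.numpy as jnp

def noisy_policy(obs, scale, rng):
    """Hypothetical policy for one population member: obs is (batch, features)."""
    noise = jax.random.normal(rng, obs.shape)
    return obs * scale + 0.1 * noise

pop, batch, feat = 4, 8, 3
states = jnp.ones((pop, batch, feat))
scales = jnp.arange(1.0, pop + 1.0)  # stand-in for per-member parameters

# One key per population member, so each member draws independent noise.
rngs = jax.random.split(jax.random.PRNGKey(0), pop)

# All three arguments are sliced along axis 0 (the population axis).
policy_pop = jax.vmap(noisy_policy, in_axes=(0, 0, 0))
out = policy_pop(states, scales, rngs)
print(out.shape)
```

Passing a single unsplit key as `rng=...` would instead ask vmap to slice that key along its own leading axis, which is unlikely to be what was intended.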

nhansendev commented 1 year ago

Never mind, I have been able to get a working example of the batching behavior using an nn.Dense layer. Initializing the MLP network is giving me trouble, but this is clearly a generic flax question, not related to evosax.