RobertTLange / evosax

Evolution Strategies in JAX 🦎
Apache License 2.0

`BatchStrategy` for simultaneous subpopulation `ask`/`tell`/`initialize` #6

Open RobertTLange opened 2 years ago

RobertTLange commented 2 years ago

I would like to add a new (abstract) strategy wrapper class (`BatchStrategy`), which takes a single strategy as input (or instantiates one) and then performs batched versions of `ask`, `tell` and `initialize`. This makes it possible to run multiple sub-populations (all with the same population size) simultaneously in a vectorized/device-parallel fashion.
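To make the "vectorized/device-parallel" distinction concrete, here is a minimal sketch of how the mapping function could be picked. `make_batch_map` and `sample_subpop` are hypothetical names for illustration, not part of evosax. The sketch assumes one sub-population per device when enough local devices exist, and falls back to `jax.vmap` otherwise.

```python
import jax
import jax.numpy as jnp


def make_batch_map(fn, num_subpops, in_axes=0):
    """Hypothetical helper: map `fn` over a leading sub-population axis."""
    n_devices = jax.local_device_count()
    if n_devices > 1 and num_subpops <= n_devices:
        # Device-parallel: one sub-population per local device.
        return jax.pmap(fn, in_axes=in_axes)
    # Single device (or too few devices): vectorize on one device instead.
    return jax.vmap(fn, in_axes=in_axes)


def sample_subpop(rng, mean):
    """Toy stand-in for a strategy's `ask`: 4 candidates around `mean`."""
    return mean + jax.random.normal(rng, (4, mean.shape[0]))


num_subpops = 2
# One independent rng key per sub-population.
rngs = jax.random.split(jax.random.PRNGKey(0), num_subpops)
mean = jnp.zeros(3)
# Keys are mapped over axis 0; the mean is shared (in_axes=None).
batch_sample = make_batch_map(sample_subpop, num_subpops, in_axes=(0, None))
batch_x = batch_sample(rngs, mean)  # shape: (num_subpops, 4, 3)
```

Both `jax.vmap` and `jax.pmap` accept the same `in_axes` convention here, so the wrapper code that calls the mapped function would not need to know which one was chosen.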

The high-level brainstorming mindmap looks as follows:

[Image: IMG_0941 (brainstorming mindmap)]

The workload can roughly be divided into 3 blocks:

Let's quickly sketch a rough design idea for the first part:

import jax


class BatchStrategy(object):
  def __init__(self, num_dims, popsize, strategy_name, subpopulations):
    self.num_dims = num_dims
    self.popsize = popsize
    self.subpopulations = subpopulations
    # Assumes popsize is divisible by subpopulations
    self.popsize_per_subpop = popsize // subpopulations
    self.strategy = ...  # set up base strategy functionalities
    # Setup map fct based on availability of devices etc. -> see problem rollouts

  def initialize(self, rng, params):
    # One rng key per subpopulation
    batch_rng = jax.random.split(rng, self.subpopulations)
    state = jax.vmap(self.strategy.initialize, ...)(batch_rng, params)
    return state

  def ask(self, rng, state, params):
    batch_rng = jax.random.split(rng, self.subpopulations)
    batch_x, state = jax.vmap(self.strategy.ask, ...)(batch_rng, state, params)
    # Flatten subpopulation proposals back into flat vector
    # batch_x -> Shape: (subpops, popsize_per_subpop, num_dims)
    # x -> Shape: (popsize, num_dims)
    x = batch_x.reshape(self.popsize, self.num_dims)
    return x, state

  def tell(self, x, fitness, state, params):
    # Reshape flat fitness/search vector into subpopulation array then tell
    # batch_fitness -> Shape: (subpops, popsize_per_subpop)
    batch_fitness = fitness.reshape(self.subpopulations, self.popsize_per_subpop)
    # batch_x -> Shape: (subpops, popsize_per_subpop, num_dims)
    batch_x = x.reshape(self.subpopulations, self.popsize_per_subpop, self.num_dims)
    state = jax.vmap(self.strategy.tell, ...)(batch_x, batch_fitness, state, params)
    return state
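
For reference, a runnable end-to-end version of the design above might look as follows. This is only a sketch under stated assumptions: `ToyGaussianES` is a hypothetical stand-in that mimics the `initialize`/`ask`/`tell` interface and is not an evosax strategy, the wrapper takes a strategy class instead of `strategy_name`, `params` is a plain dict shared by all sub-populations, and `tell` receives `x` and `fitness` explicitly, as the shape comments in the sketch imply.

```python
import jax
import jax.numpy as jnp


class ToyGaussianES:
    """Hypothetical minimal strategy; only mimics the ask/tell/initialize API."""

    def __init__(self, num_dims, popsize):
        self.num_dims = num_dims
        self.popsize = popsize

    def initialize(self, rng, params):
        mean = params["init_scale"] * jax.random.normal(rng, (self.num_dims,))
        return {"mean": mean}

    def ask(self, rng, state, params):
        noise = jax.random.normal(rng, (self.popsize, self.num_dims))
        return state["mean"] + params["sigma"] * noise, state

    def tell(self, x, fitness, state, params):
        # Move the mean to the best (lowest-fitness) candidate.
        return {"mean": x[jnp.argmin(fitness)]}


class BatchStrategy:
    def __init__(self, strategy_cls, num_dims, popsize, subpopulations):
        assert popsize % subpopulations == 0, "popsize must split evenly"
        self.num_dims = num_dims
        self.popsize = popsize
        self.subpopulations = subpopulations
        self.popsize_per_subpop = popsize // subpopulations
        self.strategy = strategy_cls(num_dims, self.popsize_per_subpop)

    def initialize(self, rng, params):
        # One key per sub-population; params are shared (in_axes=None).
        batch_rng = jax.random.split(rng, self.subpopulations)
        init = jax.vmap(self.strategy.initialize, in_axes=(0, None))
        return init(batch_rng, params)

    def ask(self, rng, state, params):
        batch_rng = jax.random.split(rng, self.subpopulations)
        ask = jax.vmap(self.strategy.ask, in_axes=(0, 0, None))
        batch_x, state = ask(batch_rng, state, params)
        # (subpops, popsize_per_subpop, num_dims) -> (popsize, num_dims)
        return batch_x.reshape(self.popsize, self.num_dims), state

    def tell(self, x, fitness, state, params):
        shape = (self.subpopulations, self.popsize_per_subpop)
        batch_x = x.reshape(*shape, self.num_dims)
        batch_fitness = fitness.reshape(*shape)
        tell = jax.vmap(self.strategy.tell, in_axes=(0, 0, 0, None))
        return tell(batch_x, batch_fitness, state, params)


params = {"init_scale": 1.0, "sigma": 0.1}
rng_init, rng_ask = jax.random.split(jax.random.PRNGKey(0))
batch_es = BatchStrategy(ToyGaussianES, num_dims=3, popsize=8, subpopulations=2)
state = batch_es.initialize(rng_init, params)
x, state = batch_es.ask(rng_ask, state, params)
fitness = jnp.sum(x**2, axis=-1)  # toy objective: squared norm
state = batch_es.tell(x, fitness, state, params)
```

From the outside, the wrapper looks like a single strategy with population size `popsize`, so an existing evaluation loop can consume the flat `x` and return a flat `fitness` vector without knowing about the sub-populations.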

CC @DiamonDiva

RobertTLange commented 2 years ago

@DiamonDiva - The basic building blocks are now implemented in the subpops subdirectory. But there are still a couple of open tasks: