RobertTLange / evosax

Evolution Strategies in JAX 🦎
Apache License 2.0

Provide a way to initialize the optimization parameters #47

Closed keraJLi closed 1 year ago

keraJLi commented 1 year ago

Currently, there does not seem to be a way to initialize the optimization parameters other than overwriting the mean of the EvoState returned by strategy.initialize. This means you also have to take care of flattening the parameters yourself, etc.
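To make the manual workaround concrete, here is a minimal sketch of what it involves. It does not use evosax itself: `ToyState` is a hypothetical stand-in for the EvoState returned by strategy.initialize, `net_params` is an invented parameter pytree, and the flattening is done with `jax.flatten_util.ravel_pytree` rather than with any evosax helper.

```python
from typing import NamedTuple

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


class ToyState(NamedTuple):
    """Hypothetical stand-in for an EvoState; only the fields needed here."""

    mean: jnp.ndarray
    sigma: float


# Hypothetical network parameters stored as a pytree.
net_params = {
    "dense": {"kernel": jnp.ones((2, 3)), "bias": jnp.zeros((3,))},
}

# Flatten the pytree into one vector; `unravel` restores the original structure.
flat_mean, unravel = ravel_pytree(net_params)

# Stand-in for the state that strategy.initialize would return.
state = ToyState(mean=jnp.zeros(flat_mean.shape[0]), sigma=0.1)

# Overwrite the search mean by hand with the flattened parameters.
state = state._replace(mean=flat_mean)

# Map the flat mean back to the pytree layout when it is needed again.
restored = unravel(state.mean)
```

The point of the sketch is the bookkeeping: the caller has to flatten the parameters, keep the unravel function around, and patch the state after initialization.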

RobertTLange commented 1 year ago

Good point. How about solving this problem via:

    @partial(jax.jit, static_argnums=(0,))
    def initialize(
        self,
        rng: chex.PRNGKey,
        params: Optional[EvoParams] = None,
        init_mean: Optional[Union[chex.Array, chex.ArrayTree]] = None,
    ) -> EvoState:
        """`initialize` the evolution strategy."""
        # Use default hyperparameters if no other settings provided
        if params is None:
            params = self.default_params

        # Initialize strategy based on strategy-specific initialize method
        state = self.initialize_strategy(rng, params)

        if init_mean is not None:
            if self.use_param_reshaper:
                init_mean = self.param_reshaper.flatten_single(init_mean)
            else:
                init_mean = jnp.asarray(init_mean)
            state = state.replace(mean=init_mean)
        return state

I will put together a new release in the coming days.
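For illustration only, the calling pattern this proposal enables could look like the sketch below. `ToyStrategy` and `ToyState` are hypothetical stand-ins and not the evosax classes; the toy uses `ravel_pytree` in place of the library's param reshaper, and it assumes the default mean is drawn from a standard normal, which is not taken from the snippet above.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


class ToyState(NamedTuple):
    """Hypothetical stand-in for an EvoState."""

    mean: jnp.ndarray


class ToyStrategy:
    """Hypothetical stand-in that mirrors the proposed `init_mean` argument."""

    def __init__(self, placeholder_params):
        # Infer the flat search dimension from a parameter pytree.
        flat, _ = ravel_pytree(placeholder_params)
        self.num_dims = flat.shape[0]

    def initialize(self, rng, init_mean=None) -> ToyState:
        # Assumed default: a random mean (a real strategy may differ).
        state = ToyState(mean=jax.random.normal(rng, (self.num_dims,)))
        if init_mean is not None:
            # Accept either a pytree or a flat array and store it flattened.
            flat, _ = ravel_pytree(init_mean)
            state = state._replace(mean=flat)
        return state


rng = jax.random.PRNGKey(0)
tree = {"w": jnp.full((2, 2), 0.5), "b": jnp.zeros((1,))}

strategy = ToyStrategy(tree)
default_state = strategy.initialize(rng)                # mean chosen by the strategy
custom_state = strategy.initialize(rng, init_mean=tree)  # mean taken from `tree`
```

With this shape of API, the caller passes the parameters in their natural structure and the flattening stays inside the strategy.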

RobertTLange commented 1 year ago

Addressed and published in release v.0.1.4. See

https://github.com/RobertTLange/evosax/blob/f0f4e058ac1ca03987d3f5311e345721fa78aa82/evosax/strategy.py#L68-L89